Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/haystack integration #19

Merged
merged 5 commits into from
Nov 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions lmformatenforcer/integrations/haystackv1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Fail fast with an actionable install hint if the optional haystack dependency is missing.
try:
    from haystack.nodes import PromptNode
except ImportError as err:
    # Chain the original import failure so the real cause stays visible in tracebacks.
    raise ImportError('haystack is not installed. Please install it with "pip install farm-haystack"') from err

import enum
from typing import Callable, Optional
from lmformatenforcer import CharacterLevelParser

class LMFormatEnforcerPromptNode(PromptNode):
    """A prompt node for the Haystack V1 API that activates the LMFormatEnforcer on the generated text."""

    class ModelType(enum.Enum):
        # Values are the class names of the supported invocation layers, so
        # _resolve_model_type can look a layer's class name up directly.
        HUGGINGFACE = 'HFLocalInvocationLayer'
        # VLLM = 'vLLMLocalInvocationLayer' TODO: After vLLM 0.22 is released, this will be possible

    def __init__(self, *args, character_level_parser: Optional[CharacterLevelParser] = None, **kwargs):
        """Create a new prompt node that activates the LMFormatEnforcer on the generated text. See PromptNode
        documentation for all of the regular arguments.
        :param character_level_parser: A CharacterLevelParser that will be used to enforce the format of the generated text"""
        super().__init__(*args, **kwargs)
        self.character_level_parser = character_level_parser
        self.model_type = self._resolve_model_type()
        self.token_enforcer_fn = self._prepare_token_enforcer_fn()

    def _prepare_token_enforcer_fn(self) -> Optional[Callable]:
        """Build the token-enforcing function for the resolved model type.
        :return: a prefix_allowed_tokens_fn for HuggingFace models, or None if no parser was configured.
        :raises NotImplementedError: if the model type has no enforcer implementation."""
        if not self.character_level_parser:
            return None
        if self.model_type == LMFormatEnforcerPromptNode.ModelType.HUGGINGFACE:
            tokenizer = self.prompt_model.model_invocation_layer.pipe.tokenizer
            # Imported lazily so transformers is only required when an HF model is actually used.
            from lmformatenforcer.integrations.transformers import build_transformers_prefix_allowed_tokens_fn
            return build_transformers_prefix_allowed_tokens_fn(tokenizer, self.character_level_parser)
        raise NotImplementedError(f"Token enforcer not implemented for model type {self.model_type.name}")

    def _resolve_model_type(self) -> ModelType:
        """Map the prompt model's invocation layer class name to a supported ModelType.
        :raises ValueError: if the invocation layer is not one of the supported types."""
        invocation_layer_name = self.prompt_model.model_invocation_layer.__class__.__name__
        try:
            return LMFormatEnforcerPromptNode.ModelType(invocation_layer_name)
        except ValueError as err:
            supported_strings = ",".join(str(t.name) for t in LMFormatEnforcerPromptNode.ModelType)
            # Chain the enum lookup failure so the original cause is preserved in tracebacks.
            raise ValueError(f"Unsupported invocation layer: {invocation_layer_name}. "
                             f"Must be one of {supported_strings}") from err

    def _prepare_model_kwargs(self):
        """Extend the base model kwargs so HuggingFace generation receives the enforcer
        as a prefix_allowed_tokens_fn (no-op when no parser was configured)."""
        model_kwargs = super()._prepare_model_kwargs()
        if self.token_enforcer_fn:
            if self.model_type == LMFormatEnforcerPromptNode.ModelType.HUGGINGFACE:
                model_kwargs.setdefault('generation_kwargs', {})['prefix_allowed_tokens_fn'] = self.token_enforcer_fn
        return model_kwargs
74 changes: 74 additions & 0 deletions lmformatenforcer/integrations/haystackv2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Fail fast with an actionable install hint if the optional haystack dependency is missing.
try:
    from haystack.preview import component
    from canals import Component
except ImportError as err:
    # Fixed: the message previously lacked the closing quote after "haystack-ai".
    # Chain the original import failure so the real cause stays visible in tracebacks.
    raise ImportError('haystack is not installed. Please install it with "pip install farm-haystack" or "pip install haystack-ai"') from err

import enum
from typing import Any, Callable, Dict, List, Optional
from lmformatenforcer import CharacterLevelParser


class _ModelType(enum.Enum):
    """Supported local generator component types for the Haystack V2 integration."""
    # Values are the class names of supported generator components;
    # the wrapped component's class name is looked up against them at runtime.
    HUGGINGFACE = 'HuggingFaceLocalGenerator'
    # VLLM = 'vLLMLocalInvocationLayer' TODO: Add this when vLLM has Haystack V2 support

@component
class LMFormatEnforcerLocalGenerator:
    """A generator component for Haystack V2 API that activates the LMFormatEnforcer on the generated text.
    It wraps a local generator, and should be added to the pipeline instead of it"""

    def __init__(self, model_component: Component, character_level_parser: Optional[CharacterLevelParser] = None):
        """Initialize the generator component
        :param model_component: A local generator component to wrap
        :param character_level_parser: A CharacterLevelParser that will be used to enforce the format of the generated text"""
        self.model_component = model_component
        self.character_level_parser = character_level_parser
        self._model_type = self._resolve_model_type()
        # Populated by warm_up(); stays None (and injection is a no-op) until then.
        self.token_enforcer_fn: Optional[Callable] = None

    @component.output_types(replies=List[str])
    def run(self, prompt: str, generation_kwargs: Optional[Dict[str, Any]] = None):
        """Run the wrapped generator with the enforcer temporarily injected.
        :param prompt: the prompt to generate from
        :param generation_kwargs: optional extra generation kwargs forwarded to the wrapped component"""
        try:
            self._inject_enforcer_into_model()
            kwargs = {}
            if generation_kwargs:
                kwargs['generation_kwargs'] = generation_kwargs
            return self.model_component.run(prompt, **kwargs)
        finally:
            # Always undo the injection, even if generation raised.
            self._release_model_injection()

    def warm_up(self):
        """Warm up the wrapped component (if it supports it) and build the token enforcer."""
        if hasattr(self.model_component, 'warm_up'):
            self.model_component.warm_up()
        self.token_enforcer_fn = self._prepare_token_enforcer_fn()

    def _prepare_token_enforcer_fn(self) -> Optional[Callable]:
        """Build the token-enforcing function for the resolved model type.
        :return: a prefix_allowed_tokens_fn for HuggingFace models, or None if no parser was configured.
        :raises NotImplementedError: if the model type has no enforcer implementation."""
        if not self.character_level_parser:
            return None
        if self._model_type == _ModelType.HUGGINGFACE:
            tokenizer = self.model_component.pipeline.tokenizer
            # Imported lazily so transformers is only required when an HF model is actually used.
            from lmformatenforcer.integrations.transformers import build_transformers_prefix_allowed_tokens_fn
            return build_transformers_prefix_allowed_tokens_fn(tokenizer, self.character_level_parser)
        raise NotImplementedError(f"Token enforcer not implemented for model type {self._model_type.name}")

    def _resolve_model_type(self) -> _ModelType:
        """Map the wrapped component's class name to a supported _ModelType.
        :raises ValueError: if the component is not one of the supported types."""
        generator_component_name = self.model_component.__class__.__name__
        try:
            return _ModelType(generator_component_name)
        except ValueError as err:
            supported_strings = ",".join(str(t.name) for t in _ModelType)
            # Chain the enum lookup failure so the original cause is preserved in tracebacks.
            raise ValueError(f"Unsupported local generator component layer: {generator_component_name}. "
                            f"Must be one of {supported_strings}") from err

    def _inject_enforcer_into_model(self):
        """Temporarily install the enforcer into the wrapped component's generation kwargs."""
        if not self.token_enforcer_fn:
            return
        if self._model_type == _ModelType.HUGGINGFACE:
            self.model_component.generation_kwargs['prefix_allowed_tokens_fn'] = self.token_enforcer_fn

    def _release_model_injection(self):
        """Remove the enforcer from the wrapped component's generation kwargs.
        Uses pop() so cleanup never raises KeyError if the key is already absent."""
        if not self.token_enforcer_fn:
            return
        if self._model_type == _ModelType.HUGGINGFACE:
            self.model_component.generation_kwargs.pop('prefix_allowed_tokens_fn', None)
Loading