Merge pull request #15 from allegro/add-support-for-gemma-model
Add support for Gemma model
riccardo-alle authored Mar 4, 2024
2 parents c1838f8 + 3e1b4fa commit 0d4c837
Showing 16 changed files with 336 additions and 43 deletions.
17 changes: 17 additions & 0 deletions README.md
@@ -24,6 +24,23 @@ Among the allms most notable features, you will find:

___

## Supported Models

| LLM Family | Hosting | Supported LLMs |
|-------------|---------------------|-----------------------------------------|
| GPT(s) | OpenAI endpoint | `gpt-3.5-turbo`, `gpt-4`, `gpt-4-turbo` |
| Google LLMs | VertexAI deployment | `text-bison@001`, `gemini-pro` |
| Llama2 | Azure deployment | `llama2-7b`, `llama2-13b`, `llama2-70b` |
| Mistral | Azure deployment | `Mistral-7b`, `Mixtral-7bx8` |
| Gemma | GCP deployment | `gemma` |

* Do you already have a subscription to a Cloud Provider for any of the models above? Configure the model using your credentials and start querying (see the sketch below)!
* Are you interested in how to self-deploy open-source models on Azure and GCP? Consult our [guide](https://allms.allegro.tech/usage/deploy_open_source_models/).
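
A minimal sketch of that flow, using the GCP-hosted `Gemma` entry from the table above (the project, region, and endpoint ID are placeholders for a model you have already deployed):

```python
from allms.models import VertexAIGemmaModel
from allms.domain.configuration import VertexAIModelGardenConfiguration

# Placeholders: point the configuration at your own GCP project and deployed endpoint.
configuration = VertexAIModelGardenConfiguration(
    cloud_project="<GCP_PROJECT_ID>",
    cloud_location="<MODEL_REGION>",
    endpoint_id="<ENDPOINT_ID>"
)

gemma_model = VertexAIGemmaModel(config=configuration)
response = gemma_model.generate("2+2 is?")
```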

___

## Documentation

Full documentation available at **[allms.allegro.tech](https://allms.allegro.tech/)**
12 changes: 11 additions & 1 deletion allms/defaults/vertex_ai.py
@@ -16,4 +16,14 @@ class GeminiModelDefaults:
TEMPERATURE = 0.0
TOP_P = 0.95
TOP_K = 40
VERBOSE = True


class GemmaModelDefaults:
GCP_MODEL_NAME = "gemma"
MODEL_TOTAL_MAX_TOKENS = 8192
MAX_OUTPUT_TOKENS = 1024
TEMPERATURE = 0.0
TOP_P = 0.95
TOP_K = 40
VERBOSE = True
7 changes: 6 additions & 1 deletion allms/domain/configuration.py
@@ -20,4 +20,9 @@ class AzureSelfDeployedConfiguration:
@dataclass
class VertexAIConfiguration:
cloud_project: str
cloud_location: str


@dataclass
class VertexAIModelGardenConfiguration(VertexAIConfiguration):
endpoint_id: str
1 change: 1 addition & 0 deletions allms/domain/enumerables.py
@@ -19,6 +19,7 @@ class AvailableModels(str, ListConvertableEnum):
AZURE_MISTRAL_MODEL = "azure_mistral"
VERTEXAI_PALM2_MODEL = "vertexai_palm2"
VERTEXAI_GEMINI_MODEL = "vertexai_gemini"
VERTEXAI_GEMMA_MODEL = "vertexai_gemma"


class LanguageModelTask(str, ListConvertableEnum):
5 changes: 4 additions & 1 deletion allms/models/__init__.py
@@ -7,13 +7,15 @@
from allms.models.azure_openai import AzureOpenAIModel
from allms.models.vertexai_gemini import VertexAIGeminiModel
from allms.models.vertexai_palm import VertexAIPalmModel
from allms.models.vertexai_gemma import VertexAIGemmaModel

__all__ = [
"AzureOpenAIModel",
"AzureLlama2Model",
"AzureMistralModel",
"VertexAIPalmModel",
"VertexAIGeminiModel",
"VertexAIGemmaModel",
"get_available_models"
]

@@ -24,6 +26,7 @@ def get_available_models() -> dict[str, Type[AbstractModel]]:
AvailableModels.AZURE_LLAMA2_MODEL: AzureLlama2Model,
AvailableModels.AZURE_MISTRAL_MODEL: AzureMistralModel,
AvailableModels.VERTEXAI_PALM2_MODEL: VertexAIPalmModel,
AvailableModels.VERTEXAI_GEMINI_MODEL: VertexAIGeminiModel
AvailableModels.VERTEXAI_GEMINI_MODEL: VertexAIGeminiModel,
AvailableModels.VERTEXAI_GEMMA_MODEL: VertexAIGemmaModel,
}
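
For illustration, the updated registry can be used to look up the newly added Gemma entry by its enum value (a small sketch based only on the mapping above):

```python
from allms.domain.enumerables import AvailableModels
from allms.models import get_available_models

# The registry maps each AvailableModels member to its model class,
# so the new Gemma entry resolves to VertexAIGemmaModel.
gemma_cls = get_available_models()[AvailableModels.VERTEXAI_GEMMA_MODEL]
print(gemma_cls.__name__)  # VertexAIGemmaModel
```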

2 changes: 1 addition & 1 deletion allms/models/abstract.py
@@ -63,7 +63,7 @@ def __init__(
raise ValueError("max_output_tokens has to be lower than model_total_max_tokens")

self._llm = self._create_llm()
self._event_loop = event_loop if event_loop is not None else asyncio.new_event_loop()
self._event_loop = event_loop if event_loop is not None else asyncio.get_event_loop()

self._predict_example = create_base_retry_decorator(
error_types=[
66 changes: 63 additions & 3 deletions allms/models/vertexai_base.py
@@ -1,7 +1,8 @@
from typing import List, Optional, Any
from typing import List, Optional, Any, Dict

from langchain_community.llms.vertexai import VertexAI
from langchain_core.callbacks.manager import AsyncCallbackManagerForLLMRun
from google.cloud.aiplatform.models import Prediction
from langchain_community.llms.vertexai import VertexAI, VertexAIModelGarden
from langchain_core.callbacks import AsyncCallbackManagerForLLMRun
from langchain_core.outputs import LLMResult, Generation
from pydash import chain

@@ -47,3 +48,62 @@ def was_response_blocked(generation: Generation) -> bool:
llm_output=result.llm_output,
run=result.run
)


class VertexAIModelGardenWrapper(VertexAIModelGarden):
temperature: float = 0.0
max_tokens: int = 128
top_p: float = 0.95
top_k: int = 40
n: int = 1

def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.allowed_model_args = list(self._default_params.keys())

@property
def _default_params(self) -> Dict[str, Any]:
return {
"temperature": self.temperature,
"max_tokens": self.max_tokens,
"top_k": self.top_k,
"top_p": self.top_p,
"n": self.n
}

def _parse_response(self, predictions: "Prediction", prompts: List[str]) -> LLMResult:
generations: List[List[Generation]] = []
for result, prompt in zip(predictions.predictions, prompts):
if isinstance(result, str):
generations.append([Generation(text=self._parse_prediction(result, prompt))])
else:
generations.append(
[
Generation(text=self._parse_prediction(prediction, prompt))
for prediction in result
]
)
return LLMResult(generations=generations)

def _parse_prediction(self, prediction: Any, prompt: str) -> str:
parsed_prediction = super()._parse_prediction(prediction)
try:
text_to_remove = f"Prompt:\n{prompt}\nOutput:\n"
return parsed_prediction.rsplit(text_to_remove, maxsplit=1)[1]
except Exception:
raise ValueError(f"Output returned from the model doesn't follow the expected format.")

async def _agenerate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
kwargs = {**kwargs, **self._default_params}
instances = self._prepare_request(prompts, **kwargs)
response = await self.async_client.predict(
endpoint=self.endpoint_path, instances=instances
)
return self._parse_response(response, prompts)
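
For context, `_parse_prediction` above assumes the Model Garden endpoint echoes the prompt back as `Prompt:\n<prompt>\nOutput:\n<completion>` and keeps only what follows that prefix. A standalone sketch of the stripping step (illustrative values only):

```python
prompt = "2+2 is?"
raw_prediction = f"Prompt:\n{prompt}\nOutput:\n4"

# Split on the echoed prefix from the right and keep the completion part;
# if the prefix is missing, indexing [1] fails and the wrapper raises ValueError.
text_to_remove = f"Prompt:\n{prompt}\nOutput:\n"
completion = raw_prediction.rsplit(text_to_remove, maxsplit=1)[1]
print(completion)  # "4"
```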

52 changes: 52 additions & 0 deletions allms/models/vertexai_gemma.py
@@ -0,0 +1,52 @@
from asyncio import AbstractEventLoop

from langchain_community.llms.vertexai import VertexAIModelGarden
from typing import Optional

from allms.defaults.general_defaults import GeneralDefaults
from allms.defaults.vertex_ai import GemmaModelDefaults
from allms.domain.configuration import VertexAIModelGardenConfiguration
from allms.models.vertexai_base import VertexAIModelGardenWrapper
from allms.models.abstract import AbstractModel


class VertexAIGemmaModel(AbstractModel):
def __init__(
self,
config: VertexAIModelGardenConfiguration,
temperature: float = GemmaModelDefaults.TEMPERATURE,
top_k: int = GemmaModelDefaults.TOP_K,
top_p: float = GemmaModelDefaults.TOP_P,
max_output_tokens: int = GemmaModelDefaults.MAX_OUTPUT_TOKENS,
model_total_max_tokens: int = GemmaModelDefaults.MODEL_TOTAL_MAX_TOKENS,
max_concurrency: int = GeneralDefaults.MAX_CONCURRENCY,
max_retries: int = GeneralDefaults.MAX_RETRIES,
verbose: bool = GemmaModelDefaults.VERBOSE,
event_loop: Optional[AbstractEventLoop] = None
) -> None:
self._top_p = top_p
self._top_k = top_k
self._verbose = verbose
self._config = config

super().__init__(
temperature=temperature,
model_total_max_tokens=model_total_max_tokens,
max_output_tokens=max_output_tokens,
max_concurrency=max_concurrency,
max_retries=max_retries,
event_loop=event_loop
)

def _create_llm(self) -> VertexAIModelGarden:
return VertexAIModelGardenWrapper(
model_name=GemmaModelDefaults.GCP_MODEL_NAME,
max_tokens=self._max_output_tokens,
temperature=self._temperature,
top_p=self._top_p,
top_k=self._top_k,
verbose=self._verbose,
project=self._config.cloud_project,
location=self._config.cloud_location,
endpoint_id=self._config.endpoint_id
)
8 changes: 4 additions & 4 deletions docs/api/models/vertexai_gemini_model.md
@@ -6,8 +6,8 @@ __init__(
temperature: float = 0.0,
top_k: int = 40,
top_p: float = 0.95,
max_output_tokens: int = 1024,
model_total_max_tokens: int = 8192,
max_output_tokens: int = 2048,
model_total_max_tokens: int = 30720,
max_concurrency: int = 1000,
max_retries: int = 8,
verbose: bool = True
@@ -22,8 +22,8 @@
- `top_p` (`float`): Top-p changes how the model selects tokens for output. Tokens are selected from most probable to
least until the sum of their probabilities equals the top_p value. Default: `0.95`.
- `max_output_tokens` (`int`): The maximum number of tokens to generate by the model. The total length of input tokens
and generated tokens is limited by the model's context length. Default: `1024`.
- `model_total_max_tokens` (`int`): Context length of the model - maximum number of input plus generated tokens. Default: `8192`.
and generated tokens is limited by the model's context length. Default: `2048`.
- `model_total_max_tokens` (`int`): Context length of the model - maximum number of input plus generated tokens. Default: `30720`.
- `max_concurrency` (`int`): Maximum number of concurrent requests. Default: `1000`.
- `max_retries` (`int`): Maximum number of retries if a request fails. Default: `8`.
- `verbose` (`bool`): Default: `True`.
85 changes: 85 additions & 0 deletions docs/api/models/vertexai_gemma.md
@@ -0,0 +1,85 @@
## `class allms.models.VertexAIGemmaModel` API
### Methods
```python
__init__(
config: VertexAIModelGardenConfiguration,
temperature: float = 0.0,
top_k: int = 40,
top_p: float = 0.95,
max_output_tokens: int = 1024,
model_total_max_tokens: int = 8192,
max_concurrency: int = 1000,
max_retries: int = 8,
verbose: bool = True
)
```
#### Parameters
- `config` (`VertexAIModelGardenConfiguration`): An instance of the `VertexAIModelGardenConfiguration` class.
- `temperature` (`float`): The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more
random, while lower values like 0.2 will make it more focused and deterministic. Default: `0.0`.
- `top_k` (`int`): Changes how the model selects tokens for output. A top-k of 3 means that the next token is selected
from among the 3 most probable tokens. Default: `40`.
- `top_p` (`float`): Top-p changes how the model selects tokens for output. Tokens are selected from most probable to
least until the sum of their probabilities equals the top_p value. Default: `0.95`.
- `max_output_tokens` (`int`): The maximum number of tokens to generate by the model. The total length of input tokens
and generated tokens is limited by the model's context length. Default: `1024`.
- `model_total_max_tokens` (`int`): Context length of the model - maximum number of input plus generated tokens. Default: `8192`.
- `max_concurrency` (`int`): Maximum number of concurrent requests. Default: `1000`.
- `max_retries` (`int`): Maximum number of retries if a request fails. Default: `8`.
- `verbose` (`bool`): Default: `True`.

---

```python
generate(
prompt: str,
system_prompt: Optional[str] = None,
input_data: typing.Optional[typing.List[InputData]] = None,
output_data_model_class: typing.Optional[typing.Type[BaseModel]] = None
) -> typing.List[ResponseData]:
```
#### Parameters
- `prompt` (`str`): Prompt to use to query the model.
- `system_prompt` (`Optional[str]`): System prompt that will be used by the model.
- `input_data` (`Optional[List[InputData]]`): If the prompt contains symbolic variables, you can use this parameter to
generate model responses for a batch of examples. Each symbolic variable from the prompt should have a mapping provided
in the `input_mappings` of `InputData`.
- `output_data_model_class` (`Optional[Type[BaseModel]]`): If provided, forces the model to generate output in the
format defined by the passed class. The generated response is automatically parsed into this class.

#### Returns
`List[ResponseData]`: Each `ResponseData` contains the response for a single example from `input_data`. If `input_data`
is not provided, the length of this list is equal to 1, and the first element is the response for the raw prompt.
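
For the batch and structured-output parameters above, a hedged sketch (the `InputData` import path and its `id` field are assumptions made for illustration, not taken from this page):

```python
from typing import List

from pydantic import BaseModel

from allms.domain.configuration import VertexAIModelGardenConfiguration
from allms.domain.input_data import InputData  # import path assumed
from allms.models import VertexAIGemmaModel


class Summary(BaseModel):
    summary: str
    keywords: List[str]


configuration = VertexAIModelGardenConfiguration(
    cloud_project="<GCP_PROJECT_ID>",
    cloud_location="<MODEL_REGION>",
    endpoint_id="<ENDPOINT_ID>"
)
model = VertexAIGemmaModel(config=configuration)

# "{text}" is a symbolic variable; each InputData provides its mapping.
responses = model.generate(
    prompt="Summarize the following text: {text}",
    input_data=[InputData(input_mappings={"text": "Gemma is an open LLM family."}, id="1")],  # `id` assumed
    output_data_model_class=Summary,
)
```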

---

## `class allms.domain.configuration.VertexAIModelGardenConfiguration` API
```python
VertexAIModelGardenConfiguration(
cloud_project: str,
cloud_location: str,
endpoint_id: str
)
```
#### Parameters
- `cloud_project` (`str`): The GCP project to use when making Vertex API calls.
- `cloud_location` (`str`): The region to use when making API calls.
- `endpoint_id` (`str`): ID of an endpoint where the model has been deployed.

---

### Example usage

```python
from allms.models import VertexAIGemmaModel
from allms.domain.configuration import VertexAIModelGardenConfiguration

configuration = VertexAIModelGardenConfiguration(
cloud_project="<GCP_PROJECT_ID>",
cloud_location="<MODEL_REGION>",
endpoint_id="<ENDPOINT_ID>"
)

vertex_model = VertexAIGemmaModel(config=configuration)
vertex_response = vertex_model.generate("2+2 is?")
```
4 changes: 2 additions & 2 deletions docs/index.md
@@ -1,5 +1,5 @@
<p align="center">
<img src="assets/images/logo.png" alt="LLM-Wrapper Logo"/>
<img src="assets/images/logo.png" alt="aLLMs Logo"/>
</p>

# Introduction
@@ -24,5 +24,5 @@ Currently, the library supports:

* OpenAI models hosted on Microsoft Azure (`gpt-3.5-turbo`, `gpt4`, `gpt4-turbo`);
* Google Cloud Platform VertexAI models (`PaLM2`, `Gemini`);
* Open-source models `Llama2` and `Mistral` self-deployed on Azure.
* Open-source models `Llama2` and `Mistral` self-deployed on Azure, and `Gemma` self-deployed on GCP.

24 changes: 22 additions & 2 deletions docs/installation_and_quick_start.md
@@ -39,7 +39,7 @@ gpt_response = gpt_model.generate("2+2 is?")
* `<OPENAI_API_DEPLOYMENT_NAME>`: The name under which the model was deployed.
* `<OPENAI_API_MODEL_NAME>`: The underlying model's name.

### Google PaLM
### VertexAI PaLM

```python
from allms.models import VertexAIPalmModel
@@ -57,7 +57,7 @@ palm_response = palm_model.generate("2+2 is?")
* `<GCP_PROJECT_ID>`: The GCP project in which you have access to the PALM model.
* `<MODEL_REGION>`: The region where the model is deployed.

### Google Gemini
### VertexAI Gemini

```python
from allms.models import VertexAIGeminiModel
@@ -75,6 +75,26 @@ gemini_response = gemini_model.generate("2+2 is?")
* `<GCP_PROJECT_ID>`: The GCP project in which you have access to the Gemini model.
* `<MODEL_REGION>`: The region where the model is deployed.

### VertexAI Gemma

```python
from allms.models import VertexAIGemmaModel
from allms.domain.configuration import VertexAIModelGardenConfiguration

configuration = VertexAIModelGardenConfiguration(
cloud_project="<GCP_PROJECT_ID>",
cloud_location="<MODEL_REGION>",
endpoint_id="<ENDPOINT_ID>"
)

gemma_model = VertexAIGemmaModel(config=configuration)
gemma_response = gemma_model.generate("2+2 is?")
```

* `<GCP_PROJECT_ID>`: The GCP project in which you have access to the Gemma model.
* `<MODEL_REGION>`: The region where the model is deployed.
* `<ENDPOINT_ID>`: ID of an endpoint where the model has been deployed.

### Azure LLaMA 2

```python