Chains mistral example: fix dep and improve output (basetenlabs#1316)
marius-baseten authored Jan 16, 2025
1 parent d2aebd8 commit 8aadbbf
Showing 1 changed file with 27 additions and 21 deletions.
truss-chains/examples/mistral/mistral_chainlet.py
@@ -1,3 +1,4 @@
+import asyncio
 from typing import Protocol
 
 from truss.base import truss_config
@@ -10,38 +11,38 @@
 class MistralLLM(chains.ChainletBase):
     remote_config = chains.RemoteConfig(
         docker_image=chains.DockerImage(
             # The mistral model needs some extra python packages.
             pip_requirements=[
-                "transformers==4.38.1",
+                "transformers",
                 "torch==2.0.1",
                 "sentencepiece",
                 "accelerate",
                 "tokenizers",
             ]
         ),
         compute=chains.Compute(cpu_count=2, gpu="A10G"),
-        # Cache the model weights in the image and make the huggingface
-        # access token secret available to the model.
         assets=chains.Assets(
+            # Cache the model weights in the image.
             cached=[
                 truss_config.ModelRepo(
                     repo_id=MISTRAL_HF_MODEL,
                     allow_patterns=["*.json", "*.safetensors", ".model"],
                 )
             ],
+            # Make the huggingface access token secret available to the model.
             secret_keys=["hf_access_token"],
         ),
     )
 
     def __init__(
         self,
-        # Adding the `context` to the init arguments, allows us to access the
-        # huggingface token.
+        # Using the optional `context` init-argument allows access to secrets
+        # such as the huggingface token.
         context: chains.DeploymentContext = chains.depends_context(),
     ) -> None:
-        # Note the imports of the *specific* python requirements are pushed down to
-        # here. This code will only be executed on the remotely deployed chainlet,
-        # not in the local environment, so we don't need to install these packages
-        # in the local dev environment.
+        # Note: the imports of the *Chainlet-specific* python requirements are pushed
+        # down to here (so you don't need them locally or in the other Chainlet).
+        # The code here is only executed on the remotely deployed chainlet, where
+        # these dependencies are included in the docker image.
         import torch
         import transformers
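
The new comment in this hunk captures the pattern the example relies on: heavy, Chainlet-specific imports are deferred into __init__ instead of sitting at module level. A minimal sketch of the same pattern, assuming the usual `import truss_chains as chains` and a hypothetical HeavyChainlet that is not part of this diff:

import truss_chains as chains


class HeavyChainlet(chains.ChainletBase):
    def __init__(self) -> None:
        # The heavy dependency is imported here, not at module level: importing
        # this file locally (or from another Chainlet) never needs torch, while
        # the deployed Chainlet's docker image installs it via pip_requirements.
        import torch

        self._device = "cuda" if torch.cuda.is_available() else "cpu"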

@@ -57,7 +58,6 @@ def __init__(
             torch_dtype=torch.float16,
             use_auth_token=context.secrets["hf_access_token"],
         )
-
         self._generate_args = {
             "max_new_tokens": 512,
             "temperature": 1.0,
@@ -74,14 +74,18 @@
     async def run_remote(self, data: str) -> str:
         import torch
 
-        formatted_prompt = f"[INST] {data} [/INST]"
-        input_ids = self._tokenizer(
-            formatted_prompt, return_tensors="pt"
-        ).input_ids.cuda()
+        prompt = f"[INST] {data} [/INST]"
+        input_ids = self._tokenizer(prompt, return_tensors="pt").input_ids.cuda()
         with torch.no_grad():
             output = self._model.generate(inputs=input_ids, **self._generate_args)
         result = self._tokenizer.decode(output[0])
-        return result
+        return (
+            result.replace(prompt, "")
+            .replace("<s>", "")
+            .replace("</s>", "")
+            .replace("\n", " ")
+            .strip()
+        )
 
 
 class MistralP(Protocol):
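
The new return expression is the "improve output" part of this commit: the decoded string from a Mistral-instruct checkpoint typically echoes the prompt and the <s>/</s> special tokens, so the replace() chain strips them and flattens newlines. A small illustration, where the raw string is an assumed example rather than a real model decode:

prompt = "[INST] Write a really short poem about: apple [/INST]"
raw = f"<s>{prompt} Red and round,\nsweet on the ground.</s>"  # assumed decode
cleaned = (
    raw.replace(prompt, "")
    .replace("<s>", "")
    .replace("</s>", "")
    .replace("\n", " ")
    .strip()
)
print(cleaned)  # Red and round, sweet on the ground.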
@@ -93,12 +97,14 @@ class PoemGenerator(chains.ChainletBase):
     def __init__(self, mistral_llm: MistralP = chains.depends(MistralLLM)) -> None:
         self._mistral_llm = mistral_llm
 
-    async def run_remote(self, words: list[str]) -> list[str]:
+    async def run_remote(self, words: list[str]) -> dict[str, str]:
         tasks = []
         for word in words:
-            prompt = f"Generate a poem about: {word}"
+            prompt = f"Write a really short poem about: {word}"
             tasks.append(asyncio.ensure_future(self._mistral_llm.run_remote(prompt)))
-        return list(await asyncio.gather(*tasks))
+        poems = list(await asyncio.gather(*tasks))
+        return {word: poem for word, poem in zip(words, poems)}
 
 
 if __name__ == "__main__":
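
run_remote now returns a dict keyed by input word instead of a bare list. asyncio.gather preserves argument order, so zip(words, poems) re-associates each poem with its word while the requests still fan out concurrently. A self-contained sketch of the same pattern, using a hypothetical FakePoet in place of the real dependency:

import asyncio


class FakePoet:
    # Hypothetical stand-in satisfying the MistralP protocol.
    async def run_remote(self, data: str) -> str:
        return f"poem for: {data}"


async def fan_out(words: list[str]) -> dict[str, str]:
    # Fire off all requests concurrently, then pair results back to inputs.
    tasks = [asyncio.ensure_future(FakePoet().run_remote(w)) for w in words]
    poems = await asyncio.gather(*tasks)
    return dict(zip(words, poems))


print(asyncio.run(fan_out(["apple", "banana"])))
# {'apple': 'poem for: apple', 'banana': 'poem for: banana'}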
@@ -110,7 +116,7 @@ async def run_remote(self, data: str) -> str:
 
     with chains.run_local():
         poem_generator = PoemGenerator(mistral_llm=FakeMistralLLM())
-        poems = asyncio.get_event_loop().run_until_complete(
+        results = asyncio.get_event_loop().run_until_complete(
             poem_generator.run_remote(words=["apple", "banana"])
         )
-        print(poems)
+        print(results)
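
The driver still uses asyncio.get_event_loop().run_until_complete, which newer Python versions deprecate when called with no running loop. An equivalent sketch with asyncio.run, assuming the same FakeMistralLLM stub from the collapsed part of the file:

with chains.run_local():
    poem_generator = PoemGenerator(mistral_llm=FakeMistralLLM())
    results = asyncio.run(poem_generator.run_remote(words=["apple", "banana"]))
    print(results)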
