Commit
Merge pull request #98 from danmcp/removefastchatdep
Remove fastchat dependency
mergify[bot] authored Sep 23, 2024
2 parents 893b6ec + ca129ab commit 53d6abf
Showing 5 changed files with 380 additions and 11 deletions.
1 change: 0 additions & 1 deletion requirements.txt
@@ -1,5 +1,4 @@
 # SPDX-License-Identifier: Apache-2.0
-FastChat
 GitPython>=3.1.42,<4.0.0
 shortuuid
 openai>=1.13.3,<2.0.0
5 changes: 2 additions & 3 deletions src/instructlab/eval/mt_bench_answers.py
@@ -6,8 +6,6 @@
 import time
 
 # Third Party
-# TODO need to look into this dependency
-from fastchat.model.model_adapter import get_conversation_template # type: ignore
 import shortuuid
 import tqdm

@@ -20,6 +18,7 @@
     load_questions,
     temperature_config,
 )
+from .mt_bench_model_adapter import get_conversation_template # type: ignore
 
 logger = setup_logger(__name__)

@@ -61,7 +60,7 @@ def get_answer(
     choices = []
     for i in range(num_choices):
-        conv = get_conversation_template(model)
+        conv = get_conversation_template(model, "granite")
 
         turns = []
         for j in range(len(question["turns"])):
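The call sites above now pass an explicit fallback template name alongside the model path. Below is a minimal sketch of the new call, assuming the adapter resolves the model name to a registered template and falls back to the named family when it cannot; the adapter itself lives in src/instructlab/eval/mt_bench_model_adapter.py, the fifth changed file, which is not rendered on this page, and the model path below is hypothetical:

from instructlab.eval.mt_bench_model_adapter import get_conversation_template

# Hypothetical model path for illustration; "granite" is the fallback
# template family used by the answer-generation path in this diff.
conv = get_conversation_template("instructlab/granite-7b-lab", "granite")
conv.append_message(conv.roles[0], "What is the capital of France?")
conv.append_message(conv.roles[1], None)  # filled in once the model responds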
12 changes: 5 additions & 7 deletions src/instructlab/eval/mt_bench_common.py
@@ -13,15 +13,15 @@
 import time
 
 # Third Party
-from fastchat import conversation
-from fastchat.model.model_adapter import get_conversation_template # type: ignore
 import openai
 
 # First Party
 from instructlab.eval import exceptions
 
 # Local
 from .logger_config import setup_logger
+from .mt_bench_conversation import Conversation
+from .mt_bench_model_adapter import get_conversation_template
 
 logger = setup_logger(__name__)

@@ -158,7 +158,7 @@ def run_judge_single(
     rating = -1
 
     system_prompt = judge.prompt_template["system_prompt"]
-    conv = get_conversation_template(model)
+    conv = get_conversation_template(model, "mixtral")
     conv.set_system_message(system_prompt)
     conv.append_message(conv.roles[0], user_prompt)
     conv.append_message(conv.roles[1], None)
@@ -268,9 +268,7 @@ class Message(TypedDict):
     role: str
 
 
-def _get_messages(
-    conv: conversation.Conversation, merge_system_user_message: bool
-) -> list[Message]:
+def _get_messages(conv: Conversation, merge_system_user_message: bool) -> list[Message]:
     messages = conv.to_openai_api_messages()
     if (
         (merge_system_user_message or conv.name == "mistral")
@@ -285,7 +283,7 @@ def _get_messages(
 def chat_completion_openai(
     openai_client,
     model,
-    conv: conversation.Conversation,
+    conv: Conversation,
     temperature,
     max_tokens,
     merge_system_user_message: bool = False,
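With the vendored Conversation class in place, _get_messages converts a conversation into the OpenAI chat-completion format, folding the system prompt into the first user turn for templates (such as "mistral") that do not take a separate system role. A short sketch of that conversion, using only the pieces shown in this diff:

from instructlab.eval.mt_bench_conversation import get_conv_template

conv = get_conv_template("mistral")
conv.set_system_message("You are an impartial judge.")
conv.append_message(conv.roles[0], "Rate the following answer.")
conv.append_message(conv.roles[1], None)

# to_openai_api_messages() returns a system message followed by the user turn;
# for the "mistral" template, _get_messages then merges the two into a single
# user message before the request is sent.
messages = conv.to_openai_api_messages()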
213 changes: 213 additions & 0 deletions src/instructlab/eval/mt_bench_conversation.py
@@ -0,0 +1,213 @@
# SPDX-License-Identifier: Apache-2.0
"""
Conversation prompt templates.
"""

# Standard
from enum import IntEnum, auto
from typing import Dict, List, Tuple, Union
import dataclasses


class SeparatorStyle(IntEnum):
    """Separator styles."""

    ADD_COLON_SINGLE = auto()
    ADD_COLON_TWO = auto()
    ADD_COLON_SPACE_SINGLE = auto()
    NO_COLON_SINGLE = auto()
    NO_COLON_TWO = auto()
    ADD_NEW_LINE_SINGLE = auto()
    LLAMA2 = auto()
    DEFAULT = auto()


@dataclasses.dataclass
class Conversation:
    # pylint: disable=too-many-instance-attributes
    """A class that manages prompt templates and keeps all conversation history."""

    # The name of this template
    name: str
    # The template of the system prompt
    system_template: str = "{system_message}"
    # The system message
    system_message: str = ""
    # The names of two roles
    roles: Tuple[str, str] = ("USER", "ASSISTANT")
    # All messages. Each item is (role, message).
    # Each message is either a string or a tuple of (string, List[image_url]).
    messages: List[List[str | None]] = dataclasses.field(default_factory=list)
    # The number of few shot examples
    offset: int = 0
    # The separator style and configurations
    sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
    sep: str | None = "\n"
    sep2: str | None = None
    # Stop criteria (the default one is EOS token)
    stop_str: Union[str, List[str]] | None = None
    # Stops generation if meeting any token in this list
    stop_token_ids: List[int] | None = None

    def set_system_message(self, system_message: str):
        """Set the system message."""
        self.system_message = system_message

    def get_system_message(self):
        """Return the system message."""
        return self.system_message

    def append_message(self, role: str, message: str | None):
        """Append a new message."""
        self.messages.append([role, message])

    def update_last_message(self, message: str):
        """Update the last output.
        The last message is typically set to be None when constructing the prompt,
        so we need to update it in-place after getting the response from a model.
        """
        self.messages[-1][1] = message

    def to_openai_api_messages(self):
        """Convert the conversation to OpenAI chat completion format."""
        if self.system_message == "":
            ret = []
        else:
            ret = [{"role": "system", "content": self.system_message}]

        for i, (_, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append({"role": "user", "content": msg})
            else:
                if msg is not None:
                    ret.append({"role": "assistant", "content": msg})
        return ret

    def copy(self):
        return Conversation(
            name=self.name,
            system_template=self.system_template,
            system_message=self.system_message,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            stop_str=self.stop_str,
            stop_token_ids=self.stop_token_ids,
        )

    def extract_text_from_messages(self):
        # Assumed helper: dict() below calls this method, but its definition is
        # not visible in this rendering of the diff. This mirrors the upstream
        # FastChat implementation for plain-string messages.
        return [(role, msg) for role, msg in self.messages]

    def dict(self):
        return {
            "template_name": self.name,
            "system_message": self.system_message,
            "roles": self.roles,
            "messages": self.extract_text_from_messages(),
            "offset": self.offset,
        }


# A global registry for all conversation templates
conv_templates: Dict[str, Conversation] = {}


def register_conv_template(template: Conversation, override: bool = False):
    """Register a new conversation template."""
    if not override:
        assert (
            template.name not in conv_templates
        ), f"{template.name} has been registered."

    conv_templates[template.name] = template


def get_conv_template(name: str) -> Conversation:
    """Get a conversation template."""
    return conv_templates[name].copy()


# An empty template for raw conversation.
register_conv_template(
    Conversation(
        name="raw",
        system_message="",
        roles=("", ""),
        sep_style=SeparatorStyle.NO_COLON_SINGLE,
        sep="",
    )
)


# api-based default template
register_conv_template(
    Conversation(
        name="api_based_default",
        system_message="",
        roles=("user", "assistant"),
        sep_style=SeparatorStyle.DEFAULT,
        sep=None,
    )
)


# ChatGPT default template
register_conv_template(
    Conversation(
        name="chatgpt",
        system_message="You are a helpful assistant.",
        roles=("user", "assistant"),
        sep_style=SeparatorStyle.DEFAULT,
        sep=None,
    )
)

# Mistral template
# source: https://docs.mistral.ai/llm/mistral-instruct-v0.1#chat-template
register_conv_template(
    Conversation(
        name="mistral",
        system_template="[INST] {system_message}\n",
        roles=("[INST]", "[/INST]"),
        sep_style=SeparatorStyle.LLAMA2,
        sep=" ",
        sep2="</s>",
    )
)

register_conv_template(
    Conversation(
        name="labrador-chat",
        system_template="<|system|>\n{system_message}",
        system_message="""You are Labrador, an AI language model developed by IBM DMF (Data Model Factory) Alignment Team. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior. You always respond to greetings (for example, hi, hello, g'day, morning, afternoon, evening, night, what's up, nice to meet you, sup, etc) with "Hello! I am Labrador, created by the IBM DMF Alignment Team. How can I help you today?". Please do not say anything else and do not start a conversation.""",
        roles=("<|user|>", "<|assistant|>"),
        sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
        sep="\n",
        stop_str="<|endoftext|>",
    )
)

register_conv_template(
    Conversation(
        name="ibm-generic",
        system_template="<|system|>\n{system_message}",
        system_message="""You are an AI language model developed by IBM Research. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.""",
        roles=("<|user|>", "<|assistant|>"),
        sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
        sep="\n",
        stop_str="<|endoftext|>",
    )
)

register_conv_template(
    Conversation(
        name="granite-chat",
        system_template="<|system|>\n{system_message}",
        system_message="""You are Granite Chat, an AI language model developed by IBM. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.""",
        roles=("<|user|>", "<|assistant|>"),
        sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
        sep="\n",
        stop_str="<|endoftext|>",
    )
)
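The module above keeps only the subset of FastChat's template machinery that MT-Bench needs: templates are registered once at import time, and get_conv_template hands back a copy so callers can mutate their conversation without touching the registry. A usage sketch against the code as shown:

from instructlab.eval.mt_bench_conversation import get_conv_template

conv = get_conv_template("granite-chat")  # a copy; the registered template is untouched
conv.append_message(conv.roles[0], "Summarize this change.")
conv.append_message(conv.roles[1], None)

print(conv.to_openai_api_messages())
# [{'role': 'system', 'content': 'You are Granite Chat, ...'},
#  {'role': 'user', 'content': 'Summarize this change.'}]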
(The fifth changed file, src/instructlab/eval/mt_bench_model_adapter.py, which provides the new get_conversation_template, did not render on this page.)
