From caea13f4d89b994f9b827b2f9acb24501fdea0e5 Mon Sep 17 00:00:00 2001 From: plaguss Date: Fri, 11 Oct 2024 16:26:41 +0200 Subject: [PATCH 01/35] Update GenerateOutput type --- src/distilabel/llms/typing.py | 11 +++++++++-- src/distilabel/steps/typing.py | 20 +++++++++++++++++++- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/src/distilabel/llms/typing.py b/src/distilabel/llms/typing.py index a19d30cb00..a4c3e13068 100644 --- a/src/distilabel/llms/typing.py +++ b/src/distilabel/llms/typing.py @@ -12,9 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, List, TypeVar, Union +from typing import TYPE_CHECKING, Any, Dict, List, TypedDict, TypeVar, Union + +LLMOutput = List[Union[str, None]] +LLMStatistics = Dict[str, Any] + + +class GenerateOutput(TypedDict): + generations: LLMOutput + statistics: LLMStatistics -GenerateOutput = List[Union[str, None]] if TYPE_CHECKING: from numpy import floating diff --git a/src/distilabel/steps/typing.py b/src/distilabel/steps/typing.py index 720037a74f..47ffd424d7 100644 --- a/src/distilabel/steps/typing.py +++ b/src/distilabel/steps/typing.py @@ -15,7 +15,25 @@ from typing import Any, Dict, Iterator, List, Tuple, Union StepOutput = Iterator[List[Dict[str, Any]]] -"""`StepOutput` is an alias of the typing `Iterator[List[Dict[str, Any]]]`""" +# NOTE: The next iteration should be somthing like this: +# StepData = List[Dict[str, Any]] +# StepStatistics = Dict[str, Any] +# StepOutput = Iterator[Dict[str, Union[StepData, StepStatistics]]] +# """`StepOutput` is an alias of the typing. +# A step output is a dict of the form: +# { +# "outputs": [ +# {"col1": "val1", "col2": "val2"}, +# {"col1": "val1", "col2": "val2"}, +# {"col1": "val1", "col2": "val2"}, +# ], +# "statistics": { +# "llm": {}, +# "time": 12341234, +# ... +# } +# } +# """ GeneratorStepOutput = Iterator[Tuple[List[Dict[str, Any]], bool]] """`GeneratorStepOutput` is an alias of the typing `Iterator[Tuple[List[Dict[str, Any]], bool]]`""" From f7d7a0efbc854da875a999dac26d25c56c90bb5f Mon Sep 17 00:00:00 2001 From: plaguss Date: Fri, 11 Oct 2024 16:27:09 +0200 Subject: [PATCH 02/35] Add draft function to compute number of tokens given a tokenizer --- src/distilabel/llms/statistics.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 src/distilabel/llms/statistics.py diff --git a/src/distilabel/llms/statistics.py b/src/distilabel/llms/statistics.py new file mode 100644 index 0000000000..4dd8714edd --- /dev/null +++ b/src/distilabel/llms/statistics.py @@ -0,0 +1,19 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, List + + +def compute_tokens(text: str, tokenizer: Callable[[str], List[int]]) -> int: + return len(tokenizer.encode(text)) if text else 0 From dbcfafa740e80e0a9bb0ded2aa431f447b30bf52 Mon Sep 17 00:00:00 2001 From: plaguss Date: Fri, 11 Oct 2024 16:27:31 +0200 Subject: [PATCH 03/35] Refactor llm generation to return generations and statistics --- .../llms/huggingface/transformers.py | 22 ++++++++++++++++++- src/distilabel/llms/llamacpp.py | 15 ++++++++++++- .../llms/huggingface/test_transformers.py | 7 ++++-- tests/unit/llms/test_llamacpp.py | 8 +++++-- 4 files changed, 46 insertions(+), 6 deletions(-) diff --git a/src/distilabel/llms/huggingface/transformers.py b/src/distilabel/llms/huggingface/transformers.py index 27ab00e5b9..4d81991e83 100644 --- a/src/distilabel/llms/huggingface/transformers.py +++ b/src/distilabel/llms/huggingface/transformers.py @@ -20,6 +20,7 @@ from distilabel.llms.base import LLM from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin +from distilabel.llms.statistics import compute_tokens from distilabel.llms.typing import GenerateOutput from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.steps.tasks.typing import OutlinesStructuredOutputType, StandardInput @@ -233,11 +234,30 @@ def generate( # type: ignore prefix_allowed_tokens_fn=self._prefix_allowed_tokens_fn, pad_token_id=self._pipeline.tokenizer.eos_token_id, # type: ignore ) - return [ + llm_output = [ [generation["generated_text"] for generation in output] for output in outputs ] + result = [] + for input, output in zip(inputs, llm_output): + result.append( + { + "generations": output, + "statistics": { + "input_tokens": [ + compute_tokens(row["content"], self._pipeline.tokenizer) + for row in input + ], + "output_tokens": [ + compute_tokens(row, self._pipeline.tokenizer) + for row in output + ], + }, + } + ) + return result + def get_last_hidden_states( self, inputs: List["StandardInput"] ) -> List["HiddenState"]: diff --git a/src/distilabel/llms/llamacpp.py b/src/distilabel/llms/llamacpp.py index 9d158ea525..69518b924c 100644 --- a/src/distilabel/llms/llamacpp.py +++ b/src/distilabel/llms/llamacpp.py @@ -219,6 +219,8 @@ def generate( # type: ignore structured_output = self.structured_output outputs = [] + input_tokens = [] + output_tokens = [] for _ in range(num_generations): # NOTE(plaguss): There seems to be a bug in how the logits processor # is used. Basically it consumes the FSM internally, and it isn't reinitialized @@ -241,7 +243,18 @@ def generate( # type: ignore ) ) outputs.append(chat_completions["choices"][0]["message"]["content"]) - batch_outputs.append(outputs) + input_tokens.append(chat_completions["usage"]["prompt_tokens"]) + output_tokens.append(chat_completions["usage"]["completion_tokens"]) + batch_outputs.append( + { + "generations": outputs, + "statistics": { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + }, + } + ) + return batch_outputs def _prepare_structured_output( diff --git a/tests/unit/llms/huggingface/test_transformers.py b/tests/unit/llms/huggingface/test_transformers.py index 97214ef5fc..8c6e25e96f 100644 --- a/tests/unit/llms/huggingface/test_transformers.py +++ b/tests/unit/llms/huggingface/test_transformers.py @@ -53,9 +53,12 @@ def test_generate(self, transformers_llm: TransformersLLM) -> None: ], num_generations=3, ) - assert len(responses) == 2 - assert len(responses[0]) == 3 + generations = responses[0]["generations"] + statistics = responses[0]["statistics"] + assert len(generations) == 3 + assert "input_tokens" in statistics + assert "output_tokens" in statistics def test_get_last_hidden_states(self, transformers_llm: TransformersLLM) -> None: inputs = [ diff --git a/tests/unit/llms/test_llamacpp.py b/tests/unit/llms/test_llamacpp.py index 35c611722d..e3dec0f04c 100644 --- a/tests/unit/llms/test_llamacpp.py +++ b/tests/unit/llms/test_llamacpp.py @@ -54,9 +54,13 @@ def test_generate(self, llm: LlamaCppLLM) -> None: ], num_generations=3, ) - + print("RESPONSE", responses) assert len(responses) == 2 - assert len(responses[0]) == 3 + generations = responses[0]["generations"] + statistics = responses[0]["statistics"] + assert len(generations) == 3 + assert "input_tokens" in statistics + assert "output_tokens" in statistics @pytest.mark.parametrize( "structured_output, dump", From a0bf204773cd7be7a09d414f14106cf03b514272 Mon Sep 17 00:00:00 2001 From: plaguss Date: Mon, 14 Oct 2024 07:06:42 +0200 Subject: [PATCH 04/35] Move statistics from the LLM to distilabel_metadata row --- src/distilabel/steps/tasks/base.py | 44 +++++++++++++++++++++++------- 1 file changed, 34 insertions(+), 10 deletions(-) diff --git a/src/distilabel/steps/tasks/base.py b/src/distilabel/steps/tasks/base.py index 0524749e26..d27a3b80f5 100644 --- a/src/distilabel/steps/tasks/base.py +++ b/src/distilabel/steps/tasks/base.py @@ -22,6 +22,7 @@ from distilabel.constants import DISTILABEL_METADATA_KEY from distilabel.errors import DistilabelUserError from distilabel.llms.base import LLM +from distilabel.llms.typing import GenerateOutput from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.steps.base import ( GeneratorStep, @@ -33,7 +34,7 @@ from distilabel.utils.dicts import group_dicts if TYPE_CHECKING: - from distilabel.llms.typing import GenerateOutput + from distilabel.llms.typing import GenerateOutput, LLMOutput, LLMStatistics from distilabel.steps.tasks.typing import ChatType, FormattedInput from distilabel.steps.typing import StepOutput @@ -129,7 +130,7 @@ def impute_step_outputs( data = row.copy() for output in self.get_outputs().keys(): data[output] = None - data = self._maybe_add_raw_input_output( + data = self._create_metadata( data, None, None, @@ -173,20 +174,26 @@ def _format_outputs( formatted_outputs = [] for output, input in zip(outputs, inputs * len(outputs)): # type: ignore try: - formatted_output = self.format_output(output, input) - formatted_output = self._maybe_add_raw_input_output( + # Extract the generations, and move the statistics to the distilabel_metadata, + # to keep everything clean + output_generations: "LLMOutput" = output.get("generations", []) + formatted_output = self.format_output(output_generations, input) + formatted_output = self._create_metadata( formatted_output, - output, + output_generations, input, add_raw_output=self.add_raw_output, # type: ignore add_raw_input=self.add_raw_input, # type: ignore + statistics=output.get("statistics"), ) formatted_outputs.append(formatted_output) except Exception as e: self._logger.warning( # type: ignore f"Task '{self.name}' failed to format output: {e}. Saving raw response." # type: ignore ) - formatted_outputs.append(self._output_on_failure(output, input)) + formatted_outputs.append( + self._output_on_failure(output.get("generations", []), input) + ) return formatted_outputs def _output_on_failure( @@ -198,7 +205,7 @@ def _output_on_failure( # Create a dictionary with the outputs of the task (every output set to None) outputs = {output: None for output in self.outputs} outputs["model_name"] = self.llm.model_name # type: ignore - outputs = self._maybe_add_raw_input_output( + outputs = self._create_metadata( outputs, output, input, @@ -207,16 +214,29 @@ def _output_on_failure( ) return outputs - def _maybe_add_raw_input_output( + # TODO: Rename to _create_metadata + def _create_metadata( self, output: Dict[str, Any], - raw_output: Union[str, None], + raw_output: List[Union[str, None]], input: Union[str, None], add_raw_output: bool = True, add_raw_input: bool = True, - ): + statistics: Optional["LLMStatistics"] = None, + ) -> Dict[str, Any]: """Adds the raw output and or the formatted input of the LLM to the output dictionary if `add_raw_output` is True or `add_raw_input` is True. + + Args: + output: + The output dictionary after formatting the output from the LLM, + to add the raw output and or raw input. + raw_output: The raw output of the LLM (the list of generations). + input: The raw input of the LLM. + add_raw_output: Whether to add the raw output to the output dictionary. + add_raw_input: Whether to add the raw input to the output dictionary. + statistics: The statistics generated by the LLM, which should contain at least + the number of input and output tokens. """ meta = output.get(DISTILABEL_METADATA_KEY, {}) @@ -224,6 +244,8 @@ def _maybe_add_raw_input_output( meta[f"raw_output_{self.name}"] = raw_output if add_raw_input: meta[f"raw_input_{self.name}"] = self.format_input(input) if input else None + if statistics: + meta["statistics"] = statistics if meta: output[DISTILABEL_METADATA_KEY] = meta @@ -406,6 +428,8 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore formatted_inputs = self._format_inputs(inputs) # `outputs` is a list containing a list of generations per input + # `outputs` is a dict containing the LLM outputs in the `generations` + # key and the statistics in the `statistics` key outputs = self.llm.generate_outputs( inputs=formatted_inputs, num_generations=self.num_generations, # type: ignore From a51ce593c2e7dec55c36fa3526f5b0331cef5283 Mon Sep 17 00:00:00 2001 From: plaguss Date: Mon, 14 Oct 2024 07:08:59 +0200 Subject: [PATCH 05/35] Update tests and LLM outputs to run with generations and statistics as the outputs --- tests/unit/conftest.py | 19 ++++++++++--- tests/unit/steps/tasks/test_base.py | 41 ++++++++++++++++++----------- 2 files changed, 42 insertions(+), 18 deletions(-) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 0e2e157e65..905b0f7231 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -39,7 +39,11 @@ def model_name(self) -> str: async def agenerate( # type: ignore self, input: "FormattedInput", num_generations: int = 1 ) -> "GenerateOutput": - return ["output" for _ in range(num_generations)] + # return ["output" for _ in range(num_generations)] + return [ + {"generations": "output", "statistics": {"test": "test"}} + for _ in range(num_generations) + ] class DummyLLM(LLM): @@ -55,7 +59,12 @@ def model_name(self) -> str: def generate( # type: ignore self, inputs: "FormattedInput", num_generations: int = 1 ) -> List["GenerateOutput"]: - return [["output" for _ in range(num_generations)]] + return [ + [ + {"generations": "output", "statistics": {"test": "test"}} + for _ in range(num_generations) + ] + ] class DummyMagpieLLM(LLM, MagpieChatTemplateMixin): @@ -70,7 +79,11 @@ def generate( self, inputs: List["FormattedInput"], num_generations: int = 1, **kwargs: Any ) -> List["GenerateOutput"]: return [ - ["Hello Magpie" for _ in range(num_generations)] for _ in range(len(inputs)) + [ + {"generations": "output", "statistics": {"test": "test"}} + for _ in range(num_generations) + ] + for _ in range(len(inputs)) ] diff --git a/tests/unit/steps/tasks/test_base.py b/tests/unit/steps/tasks/test_base.py index 29341052fb..ef400acfc9 100644 --- a/tests/unit/steps/tasks/test_base.py +++ b/tests/unit/steps/tasks/test_base.py @@ -109,6 +109,7 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: {"content": "", "role": "system"}, {"content": "test_0", "role": "user"}, ], + "statistics": {"test": "test"}, }, }, { @@ -123,6 +124,7 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: {"content": "", "role": "system"}, {"content": "test_0", "role": "user"}, ], + "statistics": {"test": "test"}, }, }, { @@ -137,6 +139,7 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: {"content": "", "role": "system"}, {"content": "test_0", "role": "user"}, ], + "statistics": {"test": "test"}, }, }, { @@ -151,6 +154,7 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: {"content": "", "role": "system"}, {"content": "test_1", "role": "user"}, ], + "statistics": {"test": "test"}, }, }, { @@ -165,6 +169,7 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: {"content": "", "role": "system"}, {"content": "test_1", "role": "user"}, ], + "statistics": {"test": "test"}, }, }, { @@ -179,6 +184,7 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: {"content": "", "role": "system"}, {"content": "test_1", "role": "user"}, ], + "statistics": {"test": "test"}, }, }, { @@ -193,6 +199,7 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: {"content": "", "role": "system"}, {"content": "test_2", "role": "user"}, ], + "statistics": {"test": "test"}, }, }, { @@ -207,6 +214,7 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: {"content": "", "role": "system"}, {"content": "test_2", "role": "user"}, ], + "statistics": {"test": "test"}, }, }, { @@ -221,6 +229,7 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: {"content": "", "role": "system"}, {"content": "test_2", "role": "user"}, ], + "statistics": {"test": "test"}, }, }, ], @@ -256,6 +265,7 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: "role": "user", }, ], + "statistics": {"test": "test"}, }, { "raw_output_task": "output", @@ -269,6 +279,7 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: "role": "user", }, ], + "statistics": {"test": "test"}, }, { "raw_output_task": "output", @@ -282,10 +293,8 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: "role": "user", }, ], + "statistics": {"test": "test"}, }, - # {"raw_output_task": "output"}, - # {"raw_output_task": "output"}, - # {"raw_output_task": "output"}, ], }, { @@ -311,6 +320,7 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: "role": "user", }, ], + "statistics": {"test": "test"}, }, { "raw_output_task": "output", @@ -324,6 +334,7 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: "role": "user", }, ], + "statistics": {"test": "test"}, }, { "raw_output_task": "output", @@ -337,6 +348,7 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: "role": "user", }, ], + "statistics": {"test": "test"}, }, ], }, @@ -363,6 +375,7 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: "role": "user", }, ], + "statistics": {"test": "test"}, }, { "raw_output_task": "output", @@ -376,6 +389,7 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: "role": "user", }, ], + "statistics": {"test": "test"}, }, { "raw_output_task": "output", @@ -389,6 +403,7 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: "role": "user", }, ], + "statistics": {"test": "test"}, }, ], }, @@ -411,6 +426,7 @@ def test_process( group_generations=group_generations, num_generations=3, ) + task.load() result = next(task.process(input)) assert result == expected @@ -423,6 +439,7 @@ def test_process_overriding_inputs(self) -> None: num_generations=3, input_mappings={"instruction": "instruction_2"}, ) + task.load() result = next( task.process_applying_mappings( @@ -451,6 +468,7 @@ def test_process_overriding_inputs(self) -> None: }, ], "raw_output_task": "output", + "statistics": {"test": "test"}, }, "info_from_input": "info", "instruction": "instruction that won't be used but overriden by input mapping", @@ -472,6 +490,7 @@ def test_process_overriding_inputs(self) -> None: }, ], "raw_output_task": "output", + "statistics": {"test": "test"}, }, "info_from_input": "info", "instruction": "instruction that won't be used but overriden by input mapping", @@ -493,6 +512,7 @@ def test_process_overriding_inputs(self) -> None: }, ], "raw_output_task": "output", + "statistics": {"test": "test"}, }, "info_from_input": "info", "instruction": "instruction that won't be used but overriden by input mapping", @@ -676,15 +696,8 @@ def test_serialization(self) -> None: new_task = DummyTask.from_dict(task.dump()) assert isinstance(new_task, DummyTask) - @pytest.mark.parametrize( - "add_raw_output, add_raw_input", - [ - (True, False), - (False, True), - (True, True), - (False, False), - ], - ) + @pytest.mark.parametrize("add_raw_output", [True, False]) + @pytest.mark.parametrize("add_raw_input", [True, False]) def test_add_raw_input_and_or_output( self, add_raw_output: bool, add_raw_input: bool ) -> None: @@ -707,7 +720,6 @@ def test_add_raw_input_and_or_output( pprint.pprint(result) if add_raw_output or add_raw_input: - assert "distilabel_metadata" in result[0].keys() if add_raw_output: assert ( "raw_output_dummy_task_0" in result[0]["distilabel_metadata"].keys() @@ -716,5 +728,4 @@ def test_add_raw_input_and_or_output( assert ( "raw_input_dummy_task_0" in result[0]["distilabel_metadata"].keys() ) - else: - assert "distilabel_metadata" not in result[0].keys() + assert "statistics" in result[0]["distilabel_metadata"].keys() From fe5d4c526de77adbd1eed8ca291aed6090835d10 Mon Sep 17 00:00:00 2001 From: plaguss Date: Mon, 14 Oct 2024 10:59:04 +0200 Subject: [PATCH 06/35] Openai computed tokens --- .../llms/huggingface/transformers.py | 6 +- src/distilabel/llms/openai.py | 35 ++++++++++- src/distilabel/llms/statistics.py | 25 +++++++- tests/unit/llms/test_openai.py | 60 +++++++++++++++++-- 4 files changed, 113 insertions(+), 13 deletions(-) diff --git a/src/distilabel/llms/huggingface/transformers.py b/src/distilabel/llms/huggingface/transformers.py index 4d81991e83..d2c7b83ed0 100644 --- a/src/distilabel/llms/huggingface/transformers.py +++ b/src/distilabel/llms/huggingface/transformers.py @@ -246,11 +246,13 @@ def generate( # type: ignore "generations": output, "statistics": { "input_tokens": [ - compute_tokens(row["content"], self._pipeline.tokenizer) + compute_tokens( + row["content"], self._pipeline.tokenizer.encode + ) for row in input ], "output_tokens": [ - compute_tokens(row, self._pipeline.tokenizer) + compute_tokens(row, self._pipeline.tokenizer.encode) for row in output ], }, diff --git a/src/distilabel/llms/openai.py b/src/distilabel/llms/openai.py index 48cac8a50e..b696763e31 100644 --- a/src/distilabel/llms/openai.py +++ b/src/distilabel/llms/openai.py @@ -22,6 +22,7 @@ from distilabel import envs from distilabel.exceptions import DistilabelOfflineBatchGenerationNotFinishedException from distilabel.llms.base import AsyncLLM +from distilabel.llms.statistics import compute_tokens from distilabel.llms.typing import GenerateOutput from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.steps.tasks.typing import FormattedInput, InstructorStructuredOutputType @@ -32,6 +33,7 @@ from openai.types import FileObject as OpenAIFileObject from openai.types.chat import ChatCompletion as OpenAIChatCompletion from pydantic import BaseModel + from tiktoken.core import Encoding _OPENAI_API_KEY_ENV_VAR_NAME = "OPENAI_API_KEY" @@ -168,6 +170,7 @@ class User(BaseModel): _api_key_env_var: str = PrivateAttr(_OPENAI_API_KEY_ENV_VAR_NAME) _client: "OpenAI" = PrivateAttr(None) _aclient: "AsyncOpenAI" = PrivateAttr(None) + _tokenizer: "Encoding" = PrivateAttr(None) def load(self) -> None: """Loads the `AsyncOpenAI` client to benefit from async requests.""" @@ -210,6 +213,10 @@ def load(self) -> None: self._aclient = result.get("client") # type: ignore if structured_output := result.get("structured_output"): self.structured_output = structured_output + # It must be version 0.8.0 at least. + import tiktoken + + self._tokenizer = tiktoken.encoding_for_model(self.model) def unload(self) -> None: """Set clients to `None` as they both contain `thread._RLock` which cannot be pickled @@ -307,9 +314,20 @@ async def agenerate( # type: ignore kwargs = self._prepare_kwargs(kwargs, structured_output) # type: ignore completion = await self._aclient.chat.completions.create(**kwargs) # type: ignore - if structured_output: - return self._generations_from_structured_output(completion) + # Note: Instructor extracts the content from the structured output, so we need to + # add the token count + generation = self._generations_from_structured_output(completion) + + return { + "generations": generation, + "statistics": { + "input_tokens": compute_tokens(input, self._tokenizer.encode), + "output_tokens": compute_tokens( + orjson.dumps(generation).decode("utf-8"), self._tokenizer.encode + ), + }, + } return self._generations_from_openai_completion(completion) @@ -346,7 +364,18 @@ def _generations_from_openai_completion( f" Finish reason was: {choice.finish_reason}" ) generations.append(content) - return generations + + return { + "generations": generations, + "statistics": { + "input_tokens": completion.usage.prompt_tokens + if completion.usage + else 0, + "output_tokens": completion.usage.completion_tokens + if completion.usage + else 0, + }, + } def offline_batch_generate( self, diff --git a/src/distilabel/llms/statistics.py b/src/distilabel/llms/statistics.py index 4dd8714edd..8af0094b19 100644 --- a/src/distilabel/llms/statistics.py +++ b/src/distilabel/llms/statistics.py @@ -12,8 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, List +from typing import Callable, List, Union +from distilabel.steps.tasks.typing import ChatType -def compute_tokens(text: str, tokenizer: Callable[[str], List[int]]) -> int: - return len(tokenizer.encode(text)) if text else 0 + +def compute_tokens( + text_or_messages: Union[str, ChatType], tokenizer: Callable[[str], List[int]] +) -> int: + """Helper function to count the number of tokens in a text or list of messages. + + Args: + text_or_messages: Either a string response or a list of messages. + tokenizer: A callable function that take str and returns the tokenized version of the text. + + Returns: + int: _description_ + """ + if isinstance(text_or_messages, str): + text = text_or_messages + else: + # If it's a list of messages, concatenate the content of each message + text = " ".join([message["content"] for message in text_or_messages]) + + return len(tokenizer(text)) if text else 0 diff --git a/tests/unit/llms/test_openai.py b/tests/unit/llms/test_openai.py index 03fb94c1d3..c0c7fb4270 100644 --- a/tests/unit/llms/test_openai.py +++ b/tests/unit/llms/test_openai.py @@ -65,11 +65,14 @@ async def test_agenerate( llm._aclient = async_openai_mock mocked_completion = Mock( - choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))] + choices=[ + Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ...")) + ], + usage=Mock(prompt_tokens=100, completion_tokens=100), ) llm._aclient.chat.completions.create = AsyncMock(return_value=mocked_completion) - await llm.agenerate( + result = await llm.agenerate( input=[ {"role": "system", "content": ""}, { @@ -78,6 +81,10 @@ async def test_agenerate( }, ] ) + assert result == { + "generations": [" Aenean hendrerit aliquam velit. ..."], + "statistics": {"input_tokens": 100, "output_tokens": 100}, + } @pytest.mark.asyncio async def test_agenerate_structured( @@ -93,6 +100,9 @@ async def test_agenerate_structured( }, ) # type: ignore llm._aclient = async_openai_mock + import tiktoken + + llm._tokenizer = tiktoken.encoding_for_model(self.model_id) sample_user = DummyUserDetail(name="John Doe", age=30) @@ -107,7 +117,10 @@ async def test_agenerate_structured( }, ] ) - assert generation[0] == sample_user.model_dump_json() + assert isinstance(generation, dict) + generations = generation["generations"] + assert generations[0] == sample_user.model_dump_json() + assert generation["statistics"] == {"input_tokens": 10, "output_tokens": 12} @pytest.mark.skipif( sys.version_info < (3, 9), reason="`mistralai` requires Python 3.9 or higher" @@ -206,6 +219,11 @@ def test_check_and_get_batch_results( }, } ], + "usage": { + "prompt_tokens": 100, + "completion_tokens": 100, + "total_tokens": 200, + }, }, }, }, @@ -228,6 +246,11 @@ def test_check_and_get_batch_results( }, } ], + "usage": { + "prompt_tokens": 100, + "completion_tokens": 100, + "total_tokens": 200, + }, }, }, }, @@ -236,7 +259,23 @@ def test_check_and_get_batch_results( llm.load() outputs = llm._check_and_get_batch_results() - assert outputs == [["output 1"], ["output 2"]] + + assert outputs == [ + { + "generations": ["output 1"], + "statistics": { + "input_tokens": 100, + "output_tokens": 100, + }, + }, + { + "generations": ["output 2"], + "statistics": { + "input_tokens": 100, + "output_tokens": 100, + }, + }, + ] def test_check_and_get_batch_results_raises_valueerror( self, _async_openai_mock: MagicMock, _openai_mock: MagicMock @@ -322,12 +361,23 @@ def test_parse_output( }, } ], + "usage": { + "prompt_tokens": 100, + "completion_tokens": 100, + "total_tokens": 200, + }, }, } } ) - assert result == [" Aenean hendrerit aliquam velit. ..."] + assert result == { + "generations": [" Aenean hendrerit aliquam velit. ..."], + "statistics": { + "input_tokens": 100, + "output_tokens": 100, + }, + } def test_retrieve_batch_results( self, _async_openai_mock: MagicMock, openai_mock: MagicMock From 394984f5dd9c360912a766d09bff0ab1f5a7eb6d Mon Sep 17 00:00:00 2001 From: plaguss Date: Mon, 14 Oct 2024 16:33:39 +0200 Subject: [PATCH 07/35] First version of async llms with statistics --- src/distilabel/llms/anthropic.py | 39 +++++++++++++++++++++----- src/distilabel/llms/base.py | 4 +-- src/distilabel/llms/cohere.py | 46 +++++++++++++++++++++++++++---- src/distilabel/llms/groq.py | 22 +++++++++++++-- tests/unit/llms/test_anthropic.py | 19 ++++++++++--- tests/unit/llms/test_cohere.py | 38 +++++++++++++------------ tests/unit/llms/test_groq.py | 25 +++++++++++++---- tests/unit/llms/test_openai.py | 14 ++++++++-- 8 files changed, 161 insertions(+), 46 deletions(-) diff --git a/src/distilabel/llms/anthropic.py b/src/distilabel/llms/anthropic.py index f938da58d2..0d5c047bae 100644 --- a/src/distilabel/llms/anthropic.py +++ b/src/distilabel/llms/anthropic.py @@ -24,10 +24,12 @@ get_type_hints, ) +import orjson from httpx import AsyncClient from pydantic import Field, PrivateAttr, SecretStr, validate_call from distilabel.llms.base import AsyncLLM +from distilabel.llms.statistics import compute_tokens from distilabel.llms.typing import GenerateOutput from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.steps.tasks.typing import ( @@ -36,7 +38,11 @@ ) if TYPE_CHECKING: + from typing import BaseModel + from anthropic import AsyncAnthropic + from anthropic.types import Message + from tokenizers import Tokenizer _ANTHROPIC_API_KEY_ENV_VAR_NAME = "ANTHROPIC_API_KEY" @@ -142,6 +148,7 @@ class User(BaseModel): _api_key_env_var: str = PrivateAttr(default=_ANTHROPIC_API_KEY_ENV_VAR_NAME) _aclient: Optional["AsyncAnthropic"] = PrivateAttr(...) + _tokenizer: "Tokenizer" = PrivateAttr(...) def _check_model_exists(self) -> None: """Checks if the specified model exists in the available models.""" @@ -198,6 +205,10 @@ def load(self) -> None: if structured_output := result.get("structured_output"): self.structured_output = structured_output + from anthropic._tokenizers import sync_get_tokenizer + + self._tokenizer = sync_get_tokenizer() + @property def model_name(self) -> str: """Returns the model name used for the LLM.""" @@ -260,17 +271,31 @@ async def agenerate( # type: ignore if structured_output: kwargs = self._prepare_kwargs(kwargs, structured_output) - generations = [] - - completion = await self._aclient.messages.create(**kwargs) # type: ignore + completion: Union["Message", "BaseModel"] = await self._aclient.messages.create( + **kwargs + ) # type: ignore if structured_output: - generations.append(completion.model_dump_json()) - return generations + str_response = completion.model_dump_json() + return { + "generations": str_response, + "statistics": { + "input_tokens": compute_tokens(input, self._tokenizer.encode), + "output_tokens": compute_tokens( + orjson.dumps(str_response).decode("utf-8"), + self._tokenizer.encode, + ), + }, + } if (content := completion.content[0].text) is None: self._logger.warning( f"Received no response using Anthropic client (model: '{self.model}')." f" Finish reason was: {completion.stop_reason}" ) - generations.append(content) - return generations + return { + "generations": content, + "statistics": { + "input_tokens": completion.usage.input_tokens, + "output_tokens": completion.usage.output_tokens, + }, + } diff --git a/src/distilabel/llms/base.py b/src/distilabel/llms/base.py index ced6a8e041..170fcc3654 100644 --- a/src/distilabel/llms/base.py +++ b/src/distilabel/llms/base.py @@ -466,9 +466,9 @@ async def _agenerate( for input in inputs for _ in range(num_generations) ] - outputs = [outputs[0] for outputs in await asyncio.gather(*tasks)] + outputs = await asyncio.gather(*tasks) return [ - list(group) + list(group)[0] for group in grouper(outputs, n=num_generations, incomplete="ignore") ] diff --git a/src/distilabel/llms/cohere.py b/src/distilabel/llms/cohere.py index e9d0d0c0f2..8c5c30d8cb 100644 --- a/src/distilabel/llms/cohere.py +++ b/src/distilabel/llms/cohere.py @@ -23,9 +23,12 @@ Union, ) +import orjson from pydantic import Field, PrivateAttr, SecretStr, validate_call +from tokenizers import Tokenizer from distilabel.llms.base import AsyncLLM +from distilabel.llms.statistics import compute_tokens from distilabel.llms.typing import GenerateOutput from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.steps.tasks.typing import ( @@ -34,7 +37,8 @@ ) if TYPE_CHECKING: - from cohere import AsyncClient, ChatMessage + from cohere import AsyncClient, ChatMessage, NonStreamedChatResponse + from pydantic import BaseModel _COHERE_API_KEY_ENV_VAR_NAME = "COHERE_API_KEY" @@ -135,6 +139,7 @@ class User(BaseModel): _ChatMessage: Type["ChatMessage"] = PrivateAttr(...) _aclient: "AsyncClient" = PrivateAttr(...) + _tokenizer: "Tokenizer" = PrivateAttr(...) @property def model_name(self) -> str: @@ -172,6 +177,10 @@ def load(self) -> None: if structured_output := result.get("structured_output"): self.structured_output = structured_output + from cohere.manually_maintained.tokenizers import get_hf_tokenizer + + self._tokenizer: "Tokenizer" = get_hf_tokenizer(self._aclient, self.model) + def _format_chat_to_cohere( self, input: "FormattedInput" ) -> Tuple[Union[str, None], List["ChatMessage"], str]: @@ -278,16 +287,41 @@ async def agenerate( # type: ignore if structured_output: kwargs = self._prepare_kwargs(kwargs, structured_output) # type: ignore - response = await self._aclient.chat(**kwargs) # type: ignore + response: Union[ + "NonStreamedChatResponse", "BaseModel" + ] = await self._aclient.chat(**kwargs) # type: ignore if structured_output: - return [response.model_dump_json()] + # TODO: Refactor the dict response, it's quite similar in many LLMs + str_response = response.model_dump_json() + return { + "generations": str_response, + "statistics": { + "input_tokens": compute_tokens(input, self._tokenizer.encode), + "output_tokens": compute_tokens( + orjson.dumps(str_response).decode("utf-8"), + self._tokenizer.encode, + ), + }, + } if (text := response.text) == "": self._logger.warning( # type: ignore f"Received no response using Cohere client (model: '{self.model}')." f" Finish reason was: {response.finish_reason}" ) - return [None] - - return [text] + return { + "generations": None, + "statistics": { + "input_tokens": compute_tokens(input, self._tokenizer.encode), + "output_tokens": 0, + }, + } + + return { + "generations": text, + "statistics": { + "input_tokens": compute_tokens(input, self._tokenizer.encode), + "output_tokens": compute_tokens(text, self._tokenizer.encode), + }, + } diff --git a/src/distilabel/llms/groq.py b/src/distilabel/llms/groq.py index c4c2554329..f159ff9b70 100644 --- a/src/distilabel/llms/groq.py +++ b/src/distilabel/llms/groq.py @@ -229,7 +229,14 @@ async def agenerate( # type: ignore completion = await self._aclient.chat.completions.create(**kwargs) # type: ignore if structured_output: generations.append(completion.model_dump_json()) - return generations + return { + "generations": generations, + "statistics": { + # TODO: Need a way of knowing the tokenizer. + "input_tokens": 0, + "output_tokens": 0, + }, + } for choice in completion.choices: if (content := choice.message.content) is None: @@ -238,4 +245,15 @@ async def agenerate( # type: ignore f" Finish reason was: {choice.finish_reason}" ) generations.append(content) - return generations + + return { + "generations": generations, + "statistics": { + "input_tokens": completion.usage.prompt_tokens + if completion.usage + else 0, + "output_tokens": completion.usage.completion_tokens + if completion.usage + else 0, + }, + } diff --git a/tests/unit/llms/test_anthropic.py b/tests/unit/llms/test_anthropic.py index 11fee764c3..3597b8c354 100644 --- a/tests/unit/llms/test_anthropic.py +++ b/tests/unit/llms/test_anthropic.py @@ -37,12 +37,14 @@ async def test_agenerate(self, mock_anthropic: MagicMock) -> None: llm = AnthropicLLM(model="claude-3-opus-20240229", api_key="api.key") # type: ignore llm._aclient = mock_anthropic - mocked_completion = Mock() - mocked_completion.content = [Mock(text="Aenean hendrerit aliquam velit...")] + mocked_completion = Mock( + content=[Mock(text="Aenean hendrerit aliquam velit...")], + usage=Mock(input_tokens=100, output_tokens=100), + ) llm._aclient.messages.create = AsyncMock(return_value=mocked_completion) - await llm.agenerate( + result = await llm.agenerate( input=[ {"role": "system", "content": ""}, { @@ -51,6 +53,10 @@ async def test_agenerate(self, mock_anthropic: MagicMock) -> None: }, ] ) + assert result == { + "generations": "Aenean hendrerit aliquam velit...", + "statistics": {"input_tokens": 100, "output_tokens": 100}, + } @pytest.mark.asyncio async def test_agenerate_structured(self, mock_openai: MagicMock) -> None: @@ -64,6 +70,9 @@ async def test_agenerate_structured(self, mock_openai: MagicMock) -> None: }, ) # type: ignore llm._aclient = mock_openai + from anthropic._tokenizers import sync_get_tokenizer + + llm._tokenizer = sync_get_tokenizer() sample_user = DummyUserDetail(name="John Doe", age=30) @@ -78,7 +87,9 @@ async def test_agenerate_structured(self, mock_openai: MagicMock) -> None: }, ] ) - assert generation[0] == sample_user.model_dump_json() + generations = generation["generations"] + assert generations == sample_user.model_dump_json() + assert generation["statistics"] == {"input_tokens": 20, "output_tokens": 11} @pytest.mark.skipif( sys.version_info < (3, 9), reason="`mistralai` requires Python 3.9 or higher" diff --git a/tests/unit/llms/test_cohere.py b/tests/unit/llms/test_cohere.py index 2e398e01cf..c7f1e2c553 100644 --- a/tests/unit/llms/test_cohere.py +++ b/tests/unit/llms/test_cohere.py @@ -19,6 +19,7 @@ import nest_asyncio import pytest +from tokenizers import Tokenizer from distilabel.llms.cohere import CohereLLM @@ -50,16 +51,12 @@ async def test_agenerate(self, mock_async_client: mock.MagicMock) -> None: llm = CohereLLM(model="command-r") llm._aclient = mock_async_client # type: ignore - mocked_completion = mock.Mock( - choices=[ - mock.Mock( - message=mock.Mock(content=" Aenean hendrerit aliquam velit. ...") - ) - ] - ) + mocked_completion = mock.Mock(text="Aenean hendrerit aliquam velit...") llm._aclient.chat = mock.AsyncMock(return_value=mocked_completion) - await llm.agenerate( + llm._tokenizer = Tokenizer.from_pretrained("bert-base-uncased") + + result = await llm.agenerate( input=[ {"role": "system", "content": ""}, { @@ -68,6 +65,10 @@ async def test_agenerate(self, mock_async_client: mock.MagicMock) -> None: }, ] ) + assert result == { + "generations": ["Aenean hendrerit aliquam velit..."], + "statistics": {"input_tokens": 23, "output_tokens": 16}, + } @pytest.mark.skipif( sys.version_info < (3, 9), reason="`mistralai` requires Python 3.9 or higher" @@ -89,6 +90,7 @@ async def test_agenerate_structured( sample_user = DummyUserDetail(name="John Doe", age=30) llm._aclient.chat = mock.AsyncMock(return_value=sample_user) + llm._tokenizer = Tokenizer.from_pretrained("bert-base-uncased") generation = await llm.agenerate( input=[ @@ -99,25 +101,23 @@ async def test_agenerate_structured( }, ] ) - assert generation == [sample_user.model_dump_json()] + assert generation == { + "generations": [sample_user.model_dump_json()], + "statistics": {"input_tokens": 23, "output_tokens": 26}, + } @pytest.mark.asyncio async def test_generate(self, mock_async_client: mock.MagicMock) -> None: llm = CohereLLM(model="command-r") llm._aclient = mock_async_client # type: ignore - mocked_completion = mock.Mock( - choices=[ - mock.Mock( - message=mock.Mock(content=" Aenean hendrerit aliquam velit. ...") - ) - ] - ) + mocked_completion = mock.Mock(text="Aenean hendrerit aliquam velit...") llm._aclient.chat = mock.AsyncMock(return_value=mocked_completion) + llm._tokenizer = Tokenizer.from_pretrained("bert-base-uncased") nest_asyncio.apply() - llm.generate( + result = llm.generate( inputs=[ [ {"role": "system", "content": ""}, @@ -128,6 +128,10 @@ async def test_generate(self, mock_async_client: mock.MagicMock) -> None: ] ] ) + assert result == { + "generations": ["Aenean hendrerit aliquam velit..."], + "statistics": {"input_tokens": 23, "output_tokens": 16}, + } @pytest.mark.parametrize( "structured_output, dump", diff --git a/tests/unit/llms/test_groq.py b/tests/unit/llms/test_groq.py index f137750292..058512b0e7 100644 --- a/tests/unit/llms/test_groq.py +++ b/tests/unit/llms/test_groq.py @@ -38,7 +38,10 @@ async def test_agenerate(self, mock_groq: MagicMock) -> None: llm._aclient = mock_groq mocked_completion = Mock( - choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))] + choices=[ + Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ...")) + ], + usage=Mock(prompt_tokens=100, completion_tokens=100), ) llm._aclient.chat.completions.create = AsyncMock(return_value=mocked_completion) @@ -50,7 +53,10 @@ async def test_agenerate(self, mock_groq: MagicMock) -> None: "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", }, ] - ) == [" Aenean hendrerit aliquam velit. ..."] + ) == { + "generations": [" Aenean hendrerit aliquam velit. ..."], + "statistics": {"input_tokens": 100, "output_tokens": 100}, + } @pytest.mark.skipif( sys.version_info < (3, 9), reason="`mistralai` requires Python 3.9 or higher" @@ -81,7 +87,10 @@ async def test_agenerate_structured(self, mock_openai: MagicMock) -> None: }, ] ) - assert generation[0] == sample_user.model_dump_json() + assert generation == { + "generations": sample_user.model_dump_json(), + "statistics": {"input_tokens": 0, "output_tokens": 0}, + } @pytest.mark.asyncio async def test_generate(self, mock_groq: MagicMock) -> None: @@ -89,7 +98,8 @@ async def test_generate(self, mock_groq: MagicMock) -> None: llm._aclient = mock_groq mocked_completion = Mock( - choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))] + choices=[Mock(message=Mock(content="Aenean hendrerit aliquam velit..."))], + usage=Mock(prompt_tokens=100, completion_tokens=100), ) llm._aclient.chat.completions.create = AsyncMock(return_value=mocked_completion) @@ -105,7 +115,12 @@ async def test_generate(self, mock_groq: MagicMock) -> None: }, ] ] - ) == [[" Aenean hendrerit aliquam velit. ..."]] + ) == [ + { + "generations": ["Aenean hendrerit aliquam velit..."], + "statistics": {"input_tokens": 100, "output_tokens": 100}, + } + ] @pytest.mark.parametrize( "structured_output, dump", diff --git a/tests/unit/llms/test_openai.py b/tests/unit/llms/test_openai.py index c0c7fb4270..af743409fe 100644 --- a/tests/unit/llms/test_openai.py +++ b/tests/unit/llms/test_openai.py @@ -103,7 +103,6 @@ async def test_agenerate_structured( import tiktoken llm._tokenizer = tiktoken.encoding_for_model(self.model_id) - sample_user = DummyUserDetail(name="John Doe", age=30) llm._aclient.chat.completions.create = AsyncMock(return_value=sample_user) @@ -133,13 +132,16 @@ async def test_generate( llm._aclient = async_openai_mock mocked_completion = Mock( - choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))] + choices=[ + Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ...")) + ], + usage=Mock(prompt_tokens=100, completion_tokens=100), ) llm._aclient.chat.completions.create = AsyncMock(return_value=mocked_completion) nest_asyncio.apply() - llm.generate( + result = llm.generate( inputs=[ [ {"role": "system", "content": ""}, @@ -150,6 +152,12 @@ async def test_generate( ] ] ) + assert result == [ + { + "generations": [" Aenean hendrerit aliquam velit. ..."], + "statistics": {"input_tokens": 100, "output_tokens": 100}, + } + ] with pytest.raises(ValueError): llm.generate( From 9003fa0194a7d253e87efa29fb04c22456edc6d0 Mon Sep 17 00:00:00 2001 From: plaguss Date: Tue, 15 Oct 2024 16:09:50 +0200 Subject: [PATCH 08/35] Return generations with list of strings and token count from _raw_response --- src/distilabel/llms/anthropic.py | 21 ++------ src/distilabel/llms/cohere.py | 12 ++--- src/distilabel/llms/groq.py | 16 ++++--- src/distilabel/llms/litellm.py | 30 ++++++++++-- src/distilabel/llms/mistral.py | 18 +++++-- src/distilabel/llms/ollama.py | 12 ++++- src/distilabel/llms/openai.py | 38 +++------------ src/distilabel/llms/statistics.py | 10 ++-- src/distilabel/llms/vertexai.py | 12 ++++- .../huggingface/test_inference_endpoints.py | 1 - tests/unit/llms/test_anthropic.py | 37 +++++++++----- tests/unit/llms/test_cohere.py | 10 ++-- tests/unit/llms/test_groq.py | 12 +++-- tests/unit/llms/test_litellm.py | 14 +++++- tests/unit/llms/test_mistral.py | 48 ++++++++++++++----- tests/unit/llms/test_ollama.py | 22 +++++++-- tests/unit/llms/test_openai.py | 16 +++++-- tests/unit/llms/test_vertexai.py | 24 +++++++--- tests/unit/llms/utils.py | 9 +++- 19 files changed, 238 insertions(+), 124 deletions(-) diff --git a/src/distilabel/llms/anthropic.py b/src/distilabel/llms/anthropic.py index 0d5c047bae..5e19a75105 100644 --- a/src/distilabel/llms/anthropic.py +++ b/src/distilabel/llms/anthropic.py @@ -24,12 +24,10 @@ get_type_hints, ) -import orjson from httpx import AsyncClient from pydantic import Field, PrivateAttr, SecretStr, validate_call from distilabel.llms.base import AsyncLLM -from distilabel.llms.statistics import compute_tokens from distilabel.llms.typing import GenerateOutput from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.steps.tasks.typing import ( @@ -42,7 +40,6 @@ from anthropic import AsyncAnthropic from anthropic.types import Message - from tokenizers import Tokenizer _ANTHROPIC_API_KEY_ENV_VAR_NAME = "ANTHROPIC_API_KEY" @@ -148,7 +145,6 @@ class User(BaseModel): _api_key_env_var: str = PrivateAttr(default=_ANTHROPIC_API_KEY_ENV_VAR_NAME) _aclient: Optional["AsyncAnthropic"] = PrivateAttr(...) - _tokenizer: "Tokenizer" = PrivateAttr(...) def _check_model_exists(self) -> None: """Checks if the specified model exists in the available models.""" @@ -205,10 +201,6 @@ def load(self) -> None: if structured_output := result.get("structured_output"): self.structured_output = structured_output - from anthropic._tokenizers import sync_get_tokenizer - - self._tokenizer = sync_get_tokenizer() - @property def model_name(self) -> str: """Returns the model name used for the LLM.""" @@ -275,15 +267,12 @@ async def agenerate( # type: ignore **kwargs ) # type: ignore if structured_output: - str_response = completion.model_dump_json() + raw_response = completion._raw_response return { - "generations": str_response, + "generations": [completion.model_dump_json()], "statistics": { - "input_tokens": compute_tokens(input, self._tokenizer.encode), - "output_tokens": compute_tokens( - orjson.dumps(str_response).decode("utf-8"), - self._tokenizer.encode, - ), + "input_tokens": raw_response.usage.input_tokens, + "output_tokens": raw_response.usage.output_tokens, }, } @@ -293,7 +282,7 @@ async def agenerate( # type: ignore f" Finish reason was: {completion.stop_reason}" ) return { - "generations": content, + "generations": [content], "statistics": { "input_tokens": completion.usage.input_tokens, "output_tokens": completion.usage.output_tokens, diff --git a/src/distilabel/llms/cohere.py b/src/distilabel/llms/cohere.py index 8c5c30d8cb..c4be6c8f9a 100644 --- a/src/distilabel/llms/cohere.py +++ b/src/distilabel/llms/cohere.py @@ -37,7 +37,7 @@ ) if TYPE_CHECKING: - from cohere import AsyncClient, ChatMessage, NonStreamedChatResponse + from cohere import AsyncClient, ChatMessage, Message from pydantic import BaseModel @@ -287,15 +287,13 @@ async def agenerate( # type: ignore if structured_output: kwargs = self._prepare_kwargs(kwargs, structured_output) # type: ignore - response: Union[ - "NonStreamedChatResponse", "BaseModel" - ] = await self._aclient.chat(**kwargs) # type: ignore + response: Union["Message", "BaseModel"] = await self._aclient.chat(**kwargs) # type: ignore if structured_output: # TODO: Refactor the dict response, it's quite similar in many LLMs str_response = response.model_dump_json() return { - "generations": str_response, + "generations": [str_response], "statistics": { "input_tokens": compute_tokens(input, self._tokenizer.encode), "output_tokens": compute_tokens( @@ -311,7 +309,7 @@ async def agenerate( # type: ignore f" Finish reason was: {response.finish_reason}" ) return { - "generations": None, + "generations": [None], "statistics": { "input_tokens": compute_tokens(input, self._tokenizer.encode), "output_tokens": 0, @@ -319,7 +317,7 @@ async def agenerate( # type: ignore } return { - "generations": text, + "generations": [text], "statistics": { "input_tokens": compute_tokens(input, self._tokenizer.encode), "output_tokens": compute_tokens(text, self._tokenizer.encode), diff --git a/src/distilabel/llms/groq.py b/src/distilabel/llms/groq.py index f159ff9b70..2143a0ce0f 100644 --- a/src/distilabel/llms/groq.py +++ b/src/distilabel/llms/groq.py @@ -225,19 +225,21 @@ async def agenerate( # type: ignore if structured_output: kwargs = self._prepare_kwargs(kwargs, structured_output) - generations = [] completion = await self._aclient.chat.completions.create(**kwargs) # type: ignore if structured_output: - generations.append(completion.model_dump_json()) + raw_response = completion._raw_response return { - "generations": generations, + "generations": [completion.model_dump_json()], "statistics": { - # TODO: Need a way of knowing the tokenizer. - "input_tokens": 0, - "output_tokens": 0, + "input_tokens": raw_response.usage.prompt_tokens + if raw_response.usage + else 0, + "output_tokens": raw_response.usage.completion_tokens + if raw_response.usage + else 0, }, } - + generations = [] for choice in completion.choices: if (content := choice.message.content) is None: self._logger.warning( # type: ignore diff --git a/src/distilabel/llms/litellm.py b/src/distilabel/llms/litellm.py index 48361ef706..5c4646c6ef 100644 --- a/src/distilabel/llms/litellm.py +++ b/src/distilabel/llms/litellm.py @@ -15,6 +15,7 @@ import logging from typing import TYPE_CHECKING, Callable, List, Optional, Union +import orjson from pydantic import Field, PrivateAttr, validate_call from distilabel.llms.base import AsyncLLM @@ -194,6 +195,7 @@ async def agenerate( # type: ignore # noqa: C901 A list of lists of strings containing the generated responses for each input. """ import litellm + from litellm import token_counter structured_output = None if isinstance(input, tuple): @@ -256,10 +258,24 @@ async def _call_aclient_until_n_choices() -> List["Choices"]: raise e generations = [] + input_tokens = token_counter(model=self.model, messages=input) + output_tokens = 0 if self.structured_output: - generations.append([choice.model_dump_json() for choice in choices]) - return generations + for choice in choices: + generations.append(choice.model_dump_json()) + output_tokens += token_counter( + model=self.model, + text=orjson.dumps(choice.model_dump_json()).decode("utf-8"), + ) + + return { + "generations": generations, + "statistics": { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + }, + } for choice in choices: if (content := choice.message.content) is None: @@ -268,4 +284,12 @@ async def _call_aclient_until_n_choices() -> List["Choices"]: f" Finish reason was: {choice.finish_reason}" ) generations.append(content) - return generations + output_tokens += token_counter(model=self.model, text=content) + + return { + "generations": generations, + "statistics": { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + }, + } diff --git a/src/distilabel/llms/mistral.py b/src/distilabel/llms/mistral.py index a913d6ad0a..92a88ce71b 100644 --- a/src/distilabel/llms/mistral.py +++ b/src/distilabel/llms/mistral.py @@ -221,8 +221,14 @@ async def agenerate( # type: ignore completion = await self._aclient.chat.complete_async(**kwargs) # type: ignore if structured_output: - generations.append(completion.model_dump_json()) - return generations + raw_response = completion._raw_response + return { + "generations": [completion.model_dump_json()], + "statistics": { + "input_tokens": raw_response.usage.prompt_tokens, + "output_tokens": raw_response.usage.completion_tokens, + }, + } for choice in completion.choices: if (content := choice.message.content) is None: @@ -231,4 +237,10 @@ async def agenerate( # type: ignore f" Finish reason was: {choice.finish_reason}" ) generations.append(content) - return generations + return { + "generations": generations, + "statistics": { + "input_tokens": completion.usage.prompt_tokens, + "output_tokens": completion.usage.completion_tokens, + }, + } diff --git a/src/distilabel/llms/ollama.py b/src/distilabel/llms/ollama.py index fc3abd605b..37a808ed4f 100644 --- a/src/distilabel/llms/ollama.py +++ b/src/distilabel/llms/ollama.py @@ -159,6 +159,8 @@ async def agenerate( # type: ignore A list of strings as completion for the given input. """ text = None + input_tokens = 0 + output_tokens = 0 try: completion: Dict[str, Any] = await self._aclient.chat( # type: ignore model=self.model, @@ -169,10 +171,18 @@ async def agenerate( # type: ignore keep_alive=keep_alive, ) text = completion["message"]["content"] + input_tokens = completion["prompt_eval_count"] + output_tokens = completion["eval_count"] except Exception as e: self._logger.warning( # type: ignore f"⚠️ Received no response using Ollama client (model: '{self.model_name}')." f" Finish reason was: {e}" ) - return [text] + return { + "generations": [text], + "statistics": { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + }, + } diff --git a/src/distilabel/llms/openai.py b/src/distilabel/llms/openai.py index b696763e31..3c4a9688da 100644 --- a/src/distilabel/llms/openai.py +++ b/src/distilabel/llms/openai.py @@ -22,7 +22,6 @@ from distilabel import envs from distilabel.exceptions import DistilabelOfflineBatchGenerationNotFinishedException from distilabel.llms.base import AsyncLLM -from distilabel.llms.statistics import compute_tokens from distilabel.llms.typing import GenerateOutput from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.steps.tasks.typing import FormattedInput, InstructorStructuredOutputType @@ -32,8 +31,6 @@ from openai.types import Batch as OpenAIBatch from openai.types import FileObject as OpenAIFileObject from openai.types.chat import ChatCompletion as OpenAIChatCompletion - from pydantic import BaseModel - from tiktoken.core import Encoding _OPENAI_API_KEY_ENV_VAR_NAME = "OPENAI_API_KEY" @@ -170,7 +167,6 @@ class User(BaseModel): _api_key_env_var: str = PrivateAttr(_OPENAI_API_KEY_ENV_VAR_NAME) _client: "OpenAI" = PrivateAttr(None) _aclient: "AsyncOpenAI" = PrivateAttr(None) - _tokenizer: "Encoding" = PrivateAttr(None) def load(self) -> None: """Loads the `AsyncOpenAI` client to benefit from async requests.""" @@ -213,10 +209,6 @@ def load(self) -> None: self._aclient = result.get("client") # type: ignore if structured_output := result.get("structured_output"): self.structured_output = structured_output - # It must be version 0.8.0 at least. - import tiktoken - - self._tokenizer = tiktoken.encoding_for_model(self.model) def unload(self) -> None: """Set clients to `None` as they both contain `thread._RLock` which cannot be pickled @@ -315,36 +307,20 @@ async def agenerate( # type: ignore completion = await self._aclient.chat.completions.create(**kwargs) # type: ignore if structured_output: - # Note: Instructor extracts the content from the structured output, so we need to - # add the token count - generation = self._generations_from_structured_output(completion) - return { - "generations": generation, + "generations": [completion.model_dump_json()], "statistics": { - "input_tokens": compute_tokens(input, self._tokenizer.encode), - "output_tokens": compute_tokens( - orjson.dumps(generation).decode("utf-8"), self._tokenizer.encode - ), + "input_tokens": completion._raw_response.usage.prompt_tokens + if completion._raw_response + else 0, + "output_tokens": completion._raw_response.usage.completion_tokens + if completion._raw_response + else 0, }, } return self._generations_from_openai_completion(completion) - def _generations_from_structured_output( - self, completion: "BaseModel" - ) -> "GenerateOutput": - """Get the generations from the structured output object. - - Args: - completion: an instance of `pydantic.BaseModel` with the content of the structuted - output. - - Returns: - A list with the content of the structured output. - """ - return [completion.model_dump_json()] - def _generations_from_openai_completion( self, completion: "OpenAIChatCompletion" ) -> "GenerateOutput": diff --git a/src/distilabel/llms/statistics.py b/src/distilabel/llms/statistics.py index 8af0094b19..efbaa50716 100644 --- a/src/distilabel/llms/statistics.py +++ b/src/distilabel/llms/statistics.py @@ -27,12 +27,12 @@ def compute_tokens( tokenizer: A callable function that take str and returns the tokenized version of the text. Returns: - int: _description_ + The number of tokens. """ - if isinstance(text_or_messages, str): - text = text_or_messages - else: + if isinstance(text_or_messages, list): # If it's a list of messages, concatenate the content of each message text = " ".join([message["content"] for message in text_or_messages]) + else: + text = text_or_messages - return len(tokenizer(text)) if text else 0 + return len(tokenizer(text)) diff --git a/src/distilabel/llms/vertexai.py b/src/distilabel/llms/vertexai.py index 0c49fa3931..7e85fbb079 100644 --- a/src/distilabel/llms/vertexai.py +++ b/src/distilabel/llms/vertexai.py @@ -160,15 +160,25 @@ async def agenerate( # type: ignore ) text = None + input_tokens = 0 + output_tokens = 0 try: text = content.candidates[0].text + input_tokens = content.usage_metadata.prompt_token_count + output_tokens = content.usage_metadata.candidates_token_count except ValueError: self._logger.warning( # type: ignore f"Received no response using VertexAI client (model: '{self.model}')." f" Finish reason was: '{content.candidates[0].finish_reason}'." ) - return [text] + return { + "generations": [text], + "statistics": { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + }, + } def _is_gemini_model(model: str) -> bool: diff --git a/tests/unit/llms/huggingface/test_inference_endpoints.py b/tests/unit/llms/huggingface/test_inference_endpoints.py index d820122a4d..4205e1fd26 100644 --- a/tests/unit/llms/huggingface/test_inference_endpoints.py +++ b/tests/unit/llms/huggingface/test_inference_endpoints.py @@ -235,7 +235,6 @@ async def test_generate(self, mock_inference_client: MagicMock) -> None: ) nest_asyncio.apply() - assert llm.generate( inputs=[ [ diff --git a/tests/unit/llms/test_anthropic.py b/tests/unit/llms/test_anthropic.py index 3597b8c354..9d165280f8 100644 --- a/tests/unit/llms/test_anthropic.py +++ b/tests/unit/llms/test_anthropic.py @@ -54,7 +54,7 @@ async def test_agenerate(self, mock_anthropic: MagicMock) -> None: ] ) assert result == { - "generations": "Aenean hendrerit aliquam velit...", + "generations": ["Aenean hendrerit aliquam velit..."], "statistics": {"input_tokens": 100, "output_tokens": 100}, } @@ -70,12 +70,13 @@ async def test_agenerate_structured(self, mock_openai: MagicMock) -> None: }, ) # type: ignore llm._aclient = mock_openai - from anthropic._tokenizers import sync_get_tokenizer - - llm._tokenizer = sync_get_tokenizer() - - sample_user = DummyUserDetail(name="John Doe", age=30) + mocked_usage = MagicMock( + usage=MagicMock(input_tokens=100, output_tokens=100), + ) + sample_user = DummyUserDetail( + name="John Doe", age=30, _raw_response=mocked_usage + ) llm._aclient.messages.create = AsyncMock(return_value=sample_user) generation = await llm.agenerate( @@ -87,9 +88,13 @@ async def test_agenerate_structured(self, mock_openai: MagicMock) -> None: }, ] ) - generations = generation["generations"] - assert generations == sample_user.model_dump_json() - assert generation["statistics"] == {"input_tokens": 20, "output_tokens": 11} + assert generation == { + "generations": [sample_user.model_dump_json()], + "statistics": { + "input_tokens": 100, + "output_tokens": 100, + }, + } @pytest.mark.skipif( sys.version_info < (3, 9), reason="`mistralai` requires Python 3.9 or higher" @@ -99,14 +104,16 @@ async def test_generate(self, mock_anthropic: MagicMock) -> None: llm = AnthropicLLM(model="claude-3-opus-20240229") # type: ignore llm._aclient = mock_anthropic - mocked_completion = Mock() - mocked_completion.content = [Mock(text="Aenean hendrerit aliquam velit...")] + mocked_completion = Mock( + content=[Mock(text="Aenean hendrerit aliquam velit...")], + usage=Mock(input_tokens=100, output_tokens=100), + ) llm._aclient.messages.create = AsyncMock(return_value=mocked_completion) nest_asyncio.apply() - llm.generate( + result = llm.generate( inputs=[ [ {"role": "system", "content": ""}, @@ -117,6 +124,12 @@ async def test_generate(self, mock_anthropic: MagicMock) -> None: ] ] ) + assert result == [ + { + "generations": ["Aenean hendrerit aliquam velit..."], + "statistics": {"input_tokens": 100, "output_tokens": 100}, + } + ] @pytest.mark.parametrize( "structured_output, dump", diff --git a/tests/unit/llms/test_cohere.py b/tests/unit/llms/test_cohere.py index c7f1e2c553..b88da01804 100644 --- a/tests/unit/llms/test_cohere.py +++ b/tests/unit/llms/test_cohere.py @@ -128,10 +128,12 @@ async def test_generate(self, mock_async_client: mock.MagicMock) -> None: ] ] ) - assert result == { - "generations": ["Aenean hendrerit aliquam velit..."], - "statistics": {"input_tokens": 23, "output_tokens": 16}, - } + assert result == [ + { + "generations": ["Aenean hendrerit aliquam velit..."], + "statistics": {"input_tokens": 23, "output_tokens": 16}, + } + ] @pytest.mark.parametrize( "structured_output, dump", diff --git a/tests/unit/llms/test_groq.py b/tests/unit/llms/test_groq.py index 058512b0e7..534214a5ef 100644 --- a/tests/unit/llms/test_groq.py +++ b/tests/unit/llms/test_groq.py @@ -74,8 +74,12 @@ async def test_agenerate_structured(self, mock_openai: MagicMock) -> None: ) # type: ignore llm._aclient = mock_openai - sample_user = DummyUserDetail(name="John Doe", age=30) - + mocked_usage = MagicMock( + usage=MagicMock(prompt_tokens=100, completion_tokens=100), + ) + sample_user = DummyUserDetail( + name="John Doe", age=30, _raw_response=mocked_usage + ) llm._aclient.chat.completions.create = AsyncMock(return_value=sample_user) generation = await llm.agenerate( @@ -88,8 +92,8 @@ async def test_agenerate_structured(self, mock_openai: MagicMock) -> None: ] ) assert generation == { - "generations": sample_user.model_dump_json(), - "statistics": {"input_tokens": 0, "output_tokens": 0}, + "generations": [sample_user.model_dump_json()], + "statistics": {"input_tokens": 100, "output_tokens": 100}, } @pytest.mark.asyncio diff --git a/tests/unit/llms/test_litellm.py b/tests/unit/llms/test_litellm.py index 56be99e028..d0a53f66ae 100644 --- a/tests/unit/llms/test_litellm.py +++ b/tests/unit/llms/test_litellm.py @@ -42,7 +42,7 @@ async def test_agenerate(self, mock_litellm: MagicMock, model: str) -> None: ) llm._aclient = AsyncMock(return_value=mocked_completion) - await llm.agenerate( + result = await llm.agenerate( input=[ {"role": "system", "content": ""}, { @@ -51,6 +51,10 @@ async def test_agenerate(self, mock_litellm: MagicMock, model: str) -> None: }, ] ) + assert result == { + "generations": [" Aenean hendrerit aliquam velit. ..."], + "statistics": {"input_tokens": 21, "output_tokens": 11}, + } @pytest.mark.asyncio async def test_generate(self, mock_litellm: MagicMock, model: str) -> None: @@ -64,7 +68,7 @@ async def test_generate(self, mock_litellm: MagicMock, model: str) -> None: nest_asyncio.apply() - llm.generate( + result = llm.generate( inputs=[ [ {"role": "system", "content": ""}, @@ -75,6 +79,12 @@ async def test_generate(self, mock_litellm: MagicMock, model: str) -> None: ] ] ) + assert result == [ + { + "generations": [" Aenean hendrerit aliquam velit. ..."], + "statistics": {"input_tokens": 21, "output_tokens": 11}, + } + ] def test_serialization(self, _: MagicMock, model: str) -> None: llm = LiteLLM(model=model) # type: ignore diff --git a/tests/unit/llms/test_mistral.py b/tests/unit/llms/test_mistral.py index f1b7b4b28f..a613f9b49c 100644 --- a/tests/unit/llms/test_mistral.py +++ b/tests/unit/llms/test_mistral.py @@ -40,15 +40,18 @@ def test_mistral_llm(self, mock_mistral: MagicMock) -> None: @pytest.mark.asyncio async def test_agenerate(self, mock_mistral: MagicMock) -> None: - llm = MistralLLM(model="mistral-tiny", api_key="api.key") # type: ignore + llm = MistralLLM(model="mistral-small", api_key="api.key") # type: ignore llm._aclient = mock_mistral mocked_completion = Mock( - choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))] + choices=[ + Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ...")) + ], + usage=Mock(prompt_tokens=10, completion_tokens=10, total_tokens=20), ) - llm._aclient.chat = AsyncMock(return_value=mocked_completion) + llm._aclient.chat.complete_async = AsyncMock(return_value=mocked_completion) - await llm.agenerate( + result = await llm.agenerate( input=[ {"role": "system", "content": ""}, { @@ -57,11 +60,15 @@ async def test_agenerate(self, mock_mistral: MagicMock) -> None: }, ] ) + assert result == { + "generations": [" Aenean hendrerit aliquam velit. ..."], + "statistics": {"input_tokens": 10, "output_tokens": 10}, + } @pytest.mark.asyncio async def test_agenerate_structured(self, mock_mistral: MagicMock) -> None: llm = MistralLLM( - model="mistral-tiny", + model="mistral-small", api_key="api.key", structured_output={ "schema": DummyUserDetail, @@ -71,12 +78,16 @@ async def test_agenerate_structured(self, mock_mistral: MagicMock) -> None: ) # type: ignore llm._aclient = mock_mistral - sample_user = DummyUserDetail(name="John Doe", age=30) - + mocked_usage = MagicMock( + usage=MagicMock(prompt_tokens=100, completion_tokens=100), + ) + sample_user = DummyUserDetail( + name="John Doe", age=30, _raw_response=mocked_usage + ) + # llm._aclient.chat.completions.create = AsyncMock(return_value=Mock(messages=sample_user)) llm._aclient.chat.completions.create = AsyncMock(return_value=sample_user) # This should work just with the _aclient.chat method once it's fixed in instructor, and # then in our code. - # llm._aclient.chat = AsyncMock(return_value=sample_user) generation = await llm.agenerate( input=[ @@ -87,7 +98,13 @@ async def test_agenerate_structured(self, mock_mistral: MagicMock) -> None: }, ] ) - assert generation[0] == sample_user.model_dump_json() + assert generation == { + "generations": [sample_user.model_dump_json()], + "statistics": { + "input_tokens": 100, + "output_tokens": 100, + }, + } @pytest.mark.asyncio async def test_generate(self, mock_mistral: MagicMock) -> None: @@ -95,7 +112,10 @@ async def test_generate(self, mock_mistral: MagicMock) -> None: llm._aclient = mock_mistral mocked_completion = Mock( - choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))] + choices=[ + Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ...")) + ], + usage=Mock(prompt_tokens=10, completion_tokens=10, total_tokens=20), ) llm._aclient.chat = Mock( complete_async=AsyncMock(return_value=mocked_completion) @@ -103,7 +123,7 @@ async def test_generate(self, mock_mistral: MagicMock) -> None: nest_asyncio.apply() - llm.generate( + result = llm.generate( inputs=[ [ {"role": "system", "content": ""}, @@ -114,6 +134,12 @@ async def test_generate(self, mock_mistral: MagicMock) -> None: ] ] ) + assert result == [ + { + "generations": [" Aenean hendrerit aliquam velit. ..."], + "statistics": {"input_tokens": 10, "output_tokens": 10}, + } + ] @pytest.mark.parametrize( "structured_output, dump", diff --git a/tests/unit/llms/test_ollama.py b/tests/unit/llms/test_ollama.py index db31d9cb07..34f32b82a4 100644 --- a/tests/unit/llms/test_ollama.py +++ b/tests/unit/llms/test_ollama.py @@ -33,11 +33,13 @@ async def test_agenerate(self, mock_ollama: MagicMock) -> None: llm._aclient = mock_ollama mocked_completion = { - "message": {"content": " Aenean hendrerit aliquam velit. ..."} + "message": {"content": "Aenean hendrerit aliquam velit..."}, + "prompt_eval_count": 10, + "eval_count": 10, } llm._aclient.chat = AsyncMock(return_value=mocked_completion) - await llm.agenerate( + result = await llm.agenerate( input=[ {"role": "system", "content": ""}, { @@ -46,6 +48,10 @@ async def test_agenerate(self, mock_ollama: MagicMock) -> None: }, ] ) + assert result == { + "generations": ["Aenean hendrerit aliquam velit..."], + "statistics": {"input_tokens": 10, "output_tokens": 10}, + } @pytest.mark.asyncio async def test_generate(self, mock_ollama: MagicMock) -> None: @@ -53,14 +59,16 @@ async def test_generate(self, mock_ollama: MagicMock) -> None: llm._aclient = mock_ollama mocked_completion = { - "message": {"content": " Aenean hendrerit aliquam velit. ..."} + "message": {"content": "Aenean hendrerit aliquam velit..."}, + "prompt_eval_count": 10, + "eval_count": 10, } llm._aclient.chat = AsyncMock(return_value=mocked_completion) nest_asyncio.apply() - llm.generate( + result = llm.generate( inputs=[ [ {"role": "system", "content": ""}, @@ -71,6 +79,12 @@ async def test_generate(self, mock_ollama: MagicMock) -> None: ] ] ) + assert result == [ + { + "generations": ["Aenean hendrerit aliquam velit..."], + "statistics": {"input_tokens": 10, "output_tokens": 10}, + } + ] def test_serialization(self, _: MagicMock) -> None: llm = OllamaLLM(model="notus") # type: ignore diff --git a/tests/unit/llms/test_openai.py b/tests/unit/llms/test_openai.py index af743409fe..463d850f77 100644 --- a/tests/unit/llms/test_openai.py +++ b/tests/unit/llms/test_openai.py @@ -103,7 +103,13 @@ async def test_agenerate_structured( import tiktoken llm._tokenizer = tiktoken.encoding_for_model(self.model_id) - sample_user = DummyUserDetail(name="John Doe", age=30) + + mocked_usage = MagicMock( + usage=MagicMock(prompt_tokens=100, completion_tokens=100), + ) + sample_user = DummyUserDetail( + name="John Doe", age=30, _raw_response=mocked_usage + ) llm._aclient.chat.completions.create = AsyncMock(return_value=sample_user) @@ -116,10 +122,10 @@ async def test_agenerate_structured( }, ] ) - assert isinstance(generation, dict) - generations = generation["generations"] - assert generations[0] == sample_user.model_dump_json() - assert generation["statistics"] == {"input_tokens": 10, "output_tokens": 12} + assert generation == { + "generations": [sample_user.model_dump_json()], + "statistics": {"input_tokens": 100, "output_tokens": 100}, + } @pytest.mark.skipif( sys.version_info < (3, 9), reason="`mistralai` requires Python 3.9 or higher" diff --git a/tests/unit/llms/test_vertexai.py b/tests/unit/llms/test_vertexai.py index 38f5933849..c2bb14c595 100644 --- a/tests/unit/llms/test_vertexai.py +++ b/tests/unit/llms/test_vertexai.py @@ -41,9 +41,10 @@ async def test_agenerate(self, mock_generative_model: MagicMock) -> None: llm._generation_config_class = GenerationConfig mocked_completion = Mock( - choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))] + candidates=[Mock(text=" Aenean hendrerit aliquam velit. ...")], + usage_metadata=Mock(prompt_token_count=10, candidates_token_count=10), ) - llm._aclient.chat.completions.create = AsyncMock(return_value=mocked_completion) + llm._aclient.generate_content_async = AsyncMock(return_value=mocked_completion) with pytest.raises( ValueError, match="`VertexAILLM only supports the roles 'user' or 'model'." @@ -58,7 +59,7 @@ async def test_agenerate(self, mock_generative_model: MagicMock) -> None: ] ) - await llm.agenerate( + result = await llm.agenerate( input=[ {"role": "model", "content": ""}, { @@ -67,6 +68,10 @@ async def test_agenerate(self, mock_generative_model: MagicMock) -> None: }, ] ) + assert result == { + "generations": [" Aenean hendrerit aliquam velit. ..."], + "statistics": {"input_tokens": 10, "output_tokens": 10}, + } @pytest.mark.asyncio async def test_generate(self, mock_generative_model: MagicMock) -> None: @@ -77,9 +82,10 @@ async def test_generate(self, mock_generative_model: MagicMock) -> None: llm._generation_config_class = GenerationConfig mocked_completion = Mock( - choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))] + candidates=[Mock(text=" Aenean hendrerit aliquam velit. ...")], + usage_metadata=Mock(prompt_token_count=10, candidates_token_count=10), ) - llm._aclient.chat.completions.create = AsyncMock(return_value=mocked_completion) + llm._aclient.generate_content_async = AsyncMock(return_value=mocked_completion) nest_asyncio.apply() @@ -98,7 +104,7 @@ async def test_generate(self, mock_generative_model: MagicMock) -> None: ] ) - llm.generate( + result = llm.generate( inputs=[ [ {"role": "model", "content": "I am a model."}, @@ -109,6 +115,12 @@ async def test_generate(self, mock_generative_model: MagicMock) -> None: ] ] ) + assert result == [ + { + "generations": [" Aenean hendrerit aliquam velit. ..."], + "statistics": {"input_tokens": 10, "output_tokens": 10}, + } + ] def test_serialization(self, _: MagicMock) -> None: llm = VertexAILLM(model="gemini-1.0-pro") diff --git a/tests/unit/llms/utils.py b/tests/unit/llms/utils.py index 7b899253bb..1888388f6e 100644 --- a/tests/unit/llms/utils.py +++ b/tests/unit/llms/utils.py @@ -12,9 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pydantic import BaseModel +from typing import Any + +from pydantic import BaseModel, PrivateAttr class DummyUserDetail(BaseModel): name: str age: int + _raw_response: Any = PrivateAttr() + + def __init__(self, **data): + super().__init__(**data) + self._raw_response = data.get("_raw_response") From 4cf0e2f9940ae303995382c7b3551bedb65c11bf Mon Sep 17 00:00:00 2001 From: plaguss Date: Tue, 15 Oct 2024 17:02:58 +0200 Subject: [PATCH 09/35] Passing tests for inference endpoints --- .../llms/huggingface/inference_endpoints.py | 89 ++++++++------ .../huggingface/test_inference_endpoints.py | 112 ++++++++++++------ 2 files changed, 133 insertions(+), 68 deletions(-) diff --git a/src/distilabel/llms/huggingface/inference_endpoints.py b/src/distilabel/llms/huggingface/inference_endpoints.py index 3566228f56..3eef33ce54 100644 --- a/src/distilabel/llms/huggingface/inference_endpoints.py +++ b/src/distilabel/llms/huggingface/inference_endpoints.py @@ -42,6 +42,12 @@ if TYPE_CHECKING: from huggingface_hub import AsyncInferenceClient + from huggingface_hub.inference._generated.types.chat_completion import ( + ChatCompletionOutput, + ) + from huggingface_hub.inference._generated.types.text_generation import ( + TextGenerationOutput, + ) from transformers import PreTrainedTokenizer @@ -387,12 +393,12 @@ async def _generate_with_text_generation( return_full_text: bool = False, seed: Optional[int] = None, watermark: bool = False, - ) -> Union[str, None]: + ) -> GenerateOutput: structured_output = self._get_structured_output(input) completion = None try: - completion = await self._aclient.text_generation( # type: ignore + completion: "TextGenerationOutput" = await self._aclient.text_generation( # type: ignore prompt=self.prepare_input(input), # type: ignore max_new_tokens=max_new_tokens, do_sample=do_sample, @@ -409,13 +415,24 @@ async def _generate_with_text_generation( seed=seed or random.randint(0, sys.maxsize), watermark=watermark, grammar=structured_output, # type: ignore + details=True, ) except Exception as e: self._logger.warning( # type: ignore f"⚠️ Received no response using Inference Client (model: '{self.model_name}')." f" Finish reason was: {e}" ) - return completion + # NOTE: I cannot see the input tokens returned, and given that the model can be private, I cannot + # count the input tokens + return { + "generations": [completion.generated_text], + "statistics": { + "input_tokens": 0, + "output_tokens": completion.details.generated_tokens + if completion.details + else 0, + }, + } async def _generate_with_chat_completion( self, @@ -431,10 +448,10 @@ async def _generate_with_chat_completion( tool_prompt: Optional[str] = None, tools: Optional[List[Dict[str, Any]]] = None, top_p: Optional[float] = None, - ) -> Union[str, None]: + ) -> GenerateOutput: message = None try: - completion = await self._aclient.chat_completion( # type: ignore + completion: "ChatCompletionOutput" = await self._aclient.chat_completion( # type: ignore messages=input, # type: ignore max_tokens=max_new_tokens, frequency_penalty=frequency_penalty, @@ -461,7 +478,13 @@ async def _generate_with_chat_completion( f"⚠️ Received no response using Inference Client (model: '{self.model_name}')." f" Finish reason was: {e}" ) - return message + return { + "generations": [message], + "statistics": { + "input_tokens": completion.usage.prompt_tokens, + "output_tokens": completion.usage.completion_tokens, + }, + } def _check_stop_sequences( self, @@ -574,37 +597,33 @@ async def agenerate( # type: ignore stop_sequences = self._check_stop_sequences(stop_sequences) if self.tokenizer_id is None: - return [ - await self._generate_with_chat_completion( - input=input, # type: ignore - max_new_tokens=max_new_tokens, - frequency_penalty=frequency_penalty, - logit_bias=logit_bias, - presence_penalty=presence_penalty, - seed=seed, - stop_sequences=stop_sequences, - temperature=temperature, - tool_choice=tool_choice, - tool_prompt=tool_prompt, - tools=tools, - top_p=top_p, - ) - ] - - return [ - await self._generate_with_text_generation( - input=input, + return await self._generate_with_chat_completion( + input=input, # type: ignore max_new_tokens=max_new_tokens, - do_sample=do_sample, - typical_p=typical_p, - repetition_penalty=repetition_penalty, frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + presence_penalty=presence_penalty, + seed=seed, + stop_sequences=stop_sequences, temperature=temperature, + tool_choice=tool_choice, + tool_prompt=tool_prompt, + tools=tools, top_p=top_p, - top_k=top_k, - stop_sequences=stop_sequences, - return_full_text=return_full_text, - seed=seed, - watermark=watermark, ) - ] + + return await self._generate_with_text_generation( + input=input, + max_new_tokens=max_new_tokens, + do_sample=do_sample, + typical_p=typical_p, + repetition_penalty=repetition_penalty, + frequency_penalty=frequency_penalty, + temperature=temperature, + top_p=top_p, + top_k=top_k, + stop_sequences=stop_sequences, + return_full_text=return_full_text, + seed=seed, + watermark=watermark, + ) diff --git a/tests/unit/llms/huggingface/test_inference_endpoints.py b/tests/unit/llms/huggingface/test_inference_endpoints.py index 4205e1fd26..8887789d4b 100644 --- a/tests/unit/llms/huggingface/test_inference_endpoints.py +++ b/tests/unit/llms/huggingface/test_inference_endpoints.py @@ -130,17 +130,29 @@ async def test_agenerate_with_text_generation( llm.load() llm._aclient.text_generation = AsyncMock( - return_value=" Aenean hendrerit aliquam velit. ..." + return_value=MagicMock( + generated_text="Aenean hendrerit aliquam velit...", + details=MagicMock( + generated_tokens=66, + ), + ) ) - assert await llm.agenerate( + result = await llm.agenerate( input=[ { "role": "user", "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", }, ] - ) == [" Aenean hendrerit aliquam velit. ..."] + ) + assert result == { + "generations": ["Aenean hendrerit aliquam velit..."], + "statistics": { + "input_tokens": 0, + "output_tokens": 66, + }, + } @pytest.mark.asyncio async def test_agenerate_with_chat_completion( @@ -173,14 +185,21 @@ async def test_agenerate_with_chat_completion( ) ) - assert await llm.agenerate( + result = await llm.agenerate( input=[ { "role": "user", "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", }, ] - ) == [" Aenean hendrerit aliquam velit. ..."] + ) + assert result == { + "generations": [" Aenean hendrerit aliquam velit. ..."], + "statistics": { + "input_tokens": 18, + "output_tokens": 66, + }, + } @pytest.mark.asyncio async def test_agenerate_with_chat_completion_fails( @@ -213,29 +232,54 @@ async def test_agenerate_with_chat_completion_fails( ) ) - assert await llm.agenerate( + result = await llm.agenerate( input=[ { "role": "user", "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", }, ] - ) == [None] + ) + assert result == { + "generations": [None], + "statistics": { + "input_tokens": 18, + "output_tokens": 66, + }, + } @pytest.mark.asyncio async def test_generate(self, mock_inference_client: MagicMock) -> None: llm = InferenceEndpointsLLM( model_id="distilabel-internal-testing/tiny-random-mistral", - tokenizer_id="distilabel-internal-testing/tiny-random-mistral", + # tokenizer_id="distilabel-internal-testing/tiny-random-mistral", ) llm.load() - llm._aclient.text_generation = AsyncMock( - return_value=" Aenean hendrerit aliquam velit. ..." + llm._aclient.chat_completion = AsyncMock( # type: ignore + return_value=ChatCompletionOutput( # type: ignore + choices=[ + ChatCompletionOutputComplete( + finish_reason="eos_token", + index=0, + message=ChatCompletionOutputMessage( + role="assistant", + content=None, + ), + ) + ], + created=1721045246, + id="", + model="meta-llama/Meta-Llama-3-70B-Instruct", + system_fingerprint="2.1.1-dev0-sha-4327210", + usage=ChatCompletionOutputUsage( + completion_tokens=66, prompt_tokens=18, total_tokens=84 + ), + ) ) nest_asyncio.apply() - assert llm.generate( + result = llm.generate( inputs=[ [ {"role": "system", "content": ""}, @@ -245,7 +289,16 @@ async def test_generate(self, mock_inference_client: MagicMock) -> None: }, ] ] - ) == [[" Aenean hendrerit aliquam velit. ..."]] + ) + assert result == [ + { + "generations": [None], + "statistics": { + "input_tokens": 18, + "output_tokens": 66, + }, + } + ] @pytest.mark.asyncio async def test_agenerate_with_structured_output( @@ -259,39 +312,32 @@ async def test_agenerate_with_structured_output( llm.load() llm._aclient.text_generation = AsyncMock( - return_value=" Aenean hendrerit aliquam velit. ..." + return_value=MagicMock( + generated_text="Aenean hendrerit aliquam velit...", + details=MagicMock( + generated_tokens=66, + ), + ) ) - # Since there's a pseudo-random number within the generation kwargs, we set the seed # here first to ensure reproducibility within the tests random.seed(42) - assert await llm.agenerate( + result = await llm.agenerate( input=[ { "role": "user", "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", }, ] - ) == [" Aenean hendrerit aliquam velit. ..."] - - kwargs = { - "prompt": " [INST] Lorem ipsum dolor sit amet, consectetur adipiscing elit. [/INST]", - "max_new_tokens": 128, - "do_sample": False, - "typical_p": None, - "repetition_penalty": None, - "frequency_penalty": None, - "temperature": 1.0, - "top_p": None, - "top_k": None, - "stop_sequences": None, - "return_full_text": False, - "seed": 2053695854357871005, # pre-computed random value with `random.seed(42)` - "watermark": False, - "grammar": {"type": "regex", "value": "\\b[A-Z][a-z]*\\b"}, + ) + assert result == { + "generations": ["Aenean hendrerit aliquam velit..."], + "statistics": { + "input_tokens": 0, + "output_tokens": 66, + }, } - llm._aclient.text_generation.assert_called_with(**kwargs) # type: ignore def test_serialization(self, mock_inference_client: MagicMock) -> None: llm = InferenceEndpointsLLM( From 1b6d15c715771ac765f80e2c736c4f937fa560d8 Mon Sep 17 00:00:00 2001 From: plaguss Date: Wed, 16 Oct 2024 12:22:21 +0200 Subject: [PATCH 10/35] Testing vLLM with statistics --- .../llms/huggingface/transformers.py | 5 +-- src/distilabel/llms/llamacpp.py | 6 ++-- src/distilabel/llms/typing.py | 12 ++++++- src/distilabel/llms/vllm.py | 34 +++++++++++++++++-- 4 files changed, 46 insertions(+), 11 deletions(-) diff --git a/src/distilabel/llms/huggingface/transformers.py b/src/distilabel/llms/huggingface/transformers.py index d2c7b83ed0..8be23fcea2 100644 --- a/src/distilabel/llms/huggingface/transformers.py +++ b/src/distilabel/llms/huggingface/transformers.py @@ -246,10 +246,7 @@ def generate( # type: ignore "generations": output, "statistics": { "input_tokens": [ - compute_tokens( - row["content"], self._pipeline.tokenizer.encode - ) - for row in input + compute_tokens(input, self._pipeline.tokenizer.encode) ], "output_tokens": [ compute_tokens(row, self._pipeline.tokenizer.encode) diff --git a/src/distilabel/llms/llamacpp.py b/src/distilabel/llms/llamacpp.py index 69518b924c..1f1176390d 100644 --- a/src/distilabel/llms/llamacpp.py +++ b/src/distilabel/llms/llamacpp.py @@ -219,7 +219,6 @@ def generate( # type: ignore structured_output = self.structured_output outputs = [] - input_tokens = [] output_tokens = [] for _ in range(num_generations): # NOTE(plaguss): There seems to be a bug in how the logits processor @@ -243,13 +242,14 @@ def generate( # type: ignore ) ) outputs.append(chat_completions["choices"][0]["message"]["content"]) - input_tokens.append(chat_completions["usage"]["prompt_tokens"]) output_tokens.append(chat_completions["usage"]["completion_tokens"]) batch_outputs.append( { "generations": outputs, "statistics": { - "input_tokens": input_tokens, + "input_tokens": [ + chat_completions["usage"]["prompt_tokens"] + ], # Should be the same for the n_generations "output_tokens": output_tokens, }, } diff --git a/src/distilabel/llms/typing.py b/src/distilabel/llms/typing.py index a4c3e13068..512c76b471 100644 --- a/src/distilabel/llms/typing.py +++ b/src/distilabel/llms/typing.py @@ -15,7 +15,17 @@ from typing import TYPE_CHECKING, Any, Dict, List, TypedDict, TypeVar, Union LLMOutput = List[Union[str, None]] -LLMStatistics = Dict[str, Any] + + +class TokenCount(TypedDict): + input_tokens: List[int] + output_tokens: List[int] + + +LLMStatistics = Union[TokenCount, Dict[str, Any]] +"""Initially the LLMStatistics will contain the token count, but can have more variables. +They can be added once we have them defined for every LLM. +""" class GenerateOutput(TypedDict): diff --git a/src/distilabel/llms/vllm.py b/src/distilabel/llms/vllm.py index 19212755d4..43468b3a3f 100644 --- a/src/distilabel/llms/vllm.py +++ b/src/distilabel/llms/vllm.py @@ -33,6 +33,7 @@ from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin from distilabel.llms.openai import OpenAILLM +from distilabel.llms.statistics import compute_tokens from distilabel.llms.typing import GenerateOutput from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.steps.tasks.typing import FormattedInput, OutlinesStructuredOutputType @@ -41,6 +42,7 @@ from openai import OpenAI # noqa from transformers import PreTrainedTokenizer from vllm import LLM as _vLLM + from vllm.outputs import RequestOutputs from distilabel.steps.tasks.typing import StandardInput @@ -359,6 +361,7 @@ def generate( # type: ignore logits_processors.append(self._structured_output_logits_processor) batched_outputs = [] + generations = [] for prepared_inputs, structured_output in prepared_batches: if structured_output: @@ -383,7 +386,7 @@ def generate( # type: ignore **extra_sampling_params, ) - batch_outputs = self._model.generate( + batch_outputs: "RequestOutputs" = self._model.generate( prepared_inputs, sampling_params, use_tqdm=False, # type: ignore @@ -392,6 +395,20 @@ def generate( # type: ignore batched_outputs += [ [output.text for output in outputs.outputs] for outputs in batch_outputs ] + for input, outputs in zip(prepared_inputs, batch_outputs): + generations.append( + { + "generations": [output.text for output in outputs.outputs], + "statistics": { + "input_tokens": [ + compute_tokens(input, self._tokenizer.encode) + ], + "output_tokens": [ + len(output.token_ids) for output in outputs.outputs + ], + }, + } + ) # If logits_processor is set, we need to sort the outputs back to the original order # (would be needed only if we have multiple structured outputs in the dataset) @@ -399,7 +416,8 @@ def generate( # type: ignore batched_outputs = _sort_batches( batched_outputs, sorted_indices, num_generations=num_generations ) - return batched_outputs + # return batched_outputs + return generations def _prepare_structured_output( self, structured_output: Optional[OutlinesStructuredOutputType] = None @@ -604,7 +622,17 @@ async def agenerate( # type: ignore f" Finish reason was: {choice.finish_reason}" ) generations.append(text) - return generations + return { + "generations": generations, + "statistics": { + "input_tokens": completion.usage.prompt_tokens + if completion.usage + else 0, + "output_tokens": completion.usage.completion_tokens + if completion.usage + else 0, + }, + } def _sort_batches( From 8923880f85d13eb3175037eb747adcf2c1b6455d Mon Sep 17 00:00:00 2001 From: plaguss Date: Wed, 16 Oct 2024 13:34:17 +0200 Subject: [PATCH 11/35] Refactor statistics module to utils and output preparation to avoid code duplication --- src/distilabel/llms/base.py | 6 ++- src/distilabel/llms/cohere.py | 2 +- .../llms/huggingface/transformers.py | 2 +- src/distilabel/llms/openai.py | 37 ++++++++----------- .../llms/{statistics.py => utils.py} | 29 ++++++++++++++- src/distilabel/llms/vllm.py | 2 +- .../llms/huggingface/test_transformers.py | 11 ++++++ tests/unit/llms/test_vllm.py | 14 ++++++- 8 files changed, 75 insertions(+), 28 deletions(-) rename src/distilabel/llms/{statistics.py => utils.py} (61%) diff --git a/src/distilabel/llms/base.py b/src/distilabel/llms/base.py index 170fcc3654..3001a477e0 100644 --- a/src/distilabel/llms/base.py +++ b/src/distilabel/llms/base.py @@ -459,8 +459,12 @@ async def _agenerate( ) for input in inputs ] - return await asyncio.gather(*tasks) + result = await asyncio.gather(*tasks) + # TODO: Update the object returned to be the same as in synchronous LLMs with batches. + return result + + # TODO: Update the object returned to be the same as in synchronous LLMs with batches. tasks = [ asyncio.create_task(self.agenerate(input=input, **kwargs)) for input in inputs diff --git a/src/distilabel/llms/cohere.py b/src/distilabel/llms/cohere.py index c4be6c8f9a..c630a1add0 100644 --- a/src/distilabel/llms/cohere.py +++ b/src/distilabel/llms/cohere.py @@ -28,8 +28,8 @@ from tokenizers import Tokenizer from distilabel.llms.base import AsyncLLM -from distilabel.llms.statistics import compute_tokens from distilabel.llms.typing import GenerateOutput +from distilabel.llms.utils import compute_tokens from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.steps.tasks.typing import ( FormattedInput, diff --git a/src/distilabel/llms/huggingface/transformers.py b/src/distilabel/llms/huggingface/transformers.py index 8be23fcea2..90348fdc18 100644 --- a/src/distilabel/llms/huggingface/transformers.py +++ b/src/distilabel/llms/huggingface/transformers.py @@ -20,8 +20,8 @@ from distilabel.llms.base import LLM from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin -from distilabel.llms.statistics import compute_tokens from distilabel.llms.typing import GenerateOutput +from distilabel.llms.utils import compute_tokens from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.steps.tasks.typing import OutlinesStructuredOutputType, StandardInput from distilabel.utils.huggingface import HF_TOKEN_ENV_VAR diff --git a/src/distilabel/llms/openai.py b/src/distilabel/llms/openai.py index 3c4a9688da..dc60b81ceb 100644 --- a/src/distilabel/llms/openai.py +++ b/src/distilabel/llms/openai.py @@ -23,6 +23,7 @@ from distilabel.exceptions import DistilabelOfflineBatchGenerationNotFinishedException from distilabel.llms.base import AsyncLLM from distilabel.llms.typing import GenerateOutput +from distilabel.llms.utils import prepare_output from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.steps.tasks.typing import FormattedInput, InstructorStructuredOutputType @@ -32,6 +33,8 @@ from openai.types import FileObject as OpenAIFileObject from openai.types.chat import ChatCompletion as OpenAIChatCompletion + from distilabel.llms.typing import LLMStatistics + _OPENAI_API_KEY_ENV_VAR_NAME = "OPENAI_API_KEY" _OPENAI_BATCH_API_MAX_FILE_SIZE = 100 * 1024 * 1024 # 100MB @@ -307,17 +310,10 @@ async def agenerate( # type: ignore completion = await self._aclient.chat.completions.create(**kwargs) # type: ignore if structured_output: - return { - "generations": [completion.model_dump_json()], - "statistics": { - "input_tokens": completion._raw_response.usage.prompt_tokens - if completion._raw_response - else 0, - "output_tokens": completion._raw_response.usage.completion_tokens - if completion._raw_response - else 0, - }, - } + return prepare_output( + [completion.model_dump_json()], + **self._get_llm_statistics(completion._raw_response), + ) return self._generations_from_openai_completion(completion) @@ -341,17 +337,7 @@ def _generations_from_openai_completion( ) generations.append(content) - return { - "generations": generations, - "statistics": { - "input_tokens": completion.usage.prompt_tokens - if completion.usage - else 0, - "output_tokens": completion.usage.completion_tokens - if completion.usage - else 0, - }, - } + return prepare_output(generations, **self._get_llm_statistics(completion)) def offline_batch_generate( self, @@ -698,3 +684,10 @@ def _name_for_openai_files(self, file_no: int) -> str: return f"distilabel-pipeline-fileno-{file_no}.jsonl" return f"distilabel-pipeline-{envs.DISTILABEL_PIPELINE_NAME}-{envs.DISTILABEL_PIPELINE_CACHE_ID}-fileno-{file_no}.jsonl" + + @staticmethod + def _get_llm_statistics(completion: "OpenAIChatCompletion") -> "LLMStatistics": + return { + "input_tokens": completion.usage.prompt_tokens if completion else 0, + "output_tokens": completion.usage.completion_tokens if completion else 0, + } diff --git a/src/distilabel/llms/statistics.py b/src/distilabel/llms/utils.py similarity index 61% rename from src/distilabel/llms/statistics.py rename to src/distilabel/llms/utils.py index efbaa50716..b2cf307ee8 100644 --- a/src/distilabel/llms/statistics.py +++ b/src/distilabel/llms/utils.py @@ -12,10 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, List, Union +from typing import TYPE_CHECKING, Callable, List, Optional, Union from distilabel.steps.tasks.typing import ChatType +if TYPE_CHECKING: + from distilabel.llms.typing import GenerateOutput, LLMOutput + def compute_tokens( text_or_messages: Union[str, ChatType], tokenizer: Callable[[str], List[int]] @@ -36,3 +39,27 @@ def compute_tokens( text = text_or_messages return len(tokenizer(text)) + + +def prepare_output( + generations: "LLMOutput", + input_tokens: Optional[List[int]] = None, + output_tokens: Optional[List[int]] = None, +) -> "GenerateOutput": + """Helper function to prepare the output of the LLM. + + Args: + generations: The outputs from an LLM. + input_tokens: The number of tokens of the inputs. Defaults to [0]. + output_tokens: The number of tokens of the LLM response. Defaults to [0]. + + Returns: + Output generation from an LLM. + """ + return { + "generations": generations, + "statistics": { + "input_tokens": input_tokens or 0, + "output_tokens": input_tokens or 0, + }, + } diff --git a/src/distilabel/llms/vllm.py b/src/distilabel/llms/vllm.py index 43468b3a3f..c6aab8ede9 100644 --- a/src/distilabel/llms/vllm.py +++ b/src/distilabel/llms/vllm.py @@ -33,8 +33,8 @@ from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin from distilabel.llms.openai import OpenAILLM -from distilabel.llms.statistics import compute_tokens from distilabel.llms.typing import GenerateOutput +from distilabel.llms.utils import compute_tokens from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.steps.tasks.typing import FormattedInput, OutlinesStructuredOutputType diff --git a/tests/unit/llms/huggingface/test_transformers.py b/tests/unit/llms/huggingface/test_transformers.py index 8c6e25e96f..bc6cf46912 100644 --- a/tests/unit/llms/huggingface/test_transformers.py +++ b/tests/unit/llms/huggingface/test_transformers.py @@ -53,6 +53,17 @@ def test_generate(self, transformers_llm: TransformersLLM) -> None: ], num_generations=3, ) + # Note: It returns the following structure: + # [ + # { + # "generations": [text1, text2, text3], # As much as num_generations + # "statistics": { + # "input_tokens": [7], + # "output_tokens": [128, 128, 128], # The sum of the tokens of the generated texts + # }, + # }, + # {...} + # ] assert len(responses) == 2 generations = responses[0]["generations"] statistics = responses[0]["statistics"] diff --git a/tests/unit/llms/test_vllm.py b/tests/unit/llms/test_vllm.py index c1df505126..7f29790873 100644 --- a/tests/unit/llms/test_vllm.py +++ b/tests/unit/llms/test_vllm.py @@ -21,6 +21,7 @@ from openai.types import Model from openai.types.completion import Completion from openai.types.completion_choice import CompletionChoice +from openai.types.completion_usage import CompletionUsage from pydantic import BaseModel from distilabel.llms import vLLM @@ -240,6 +241,11 @@ async def test_agenerate( text="I'm fine thank you sir", ), ], + usage=CompletionUsage( + completion_tokens=10, + prompt_tokens=10, + total_tokens=20, + ), ) ) @@ -247,4 +253,10 @@ async def test_agenerate( input=[{"role": "user", "content": "Hi, how are you?"}] ) - assert generations == ["I'm fine thank you", "I'm fine thank you sir"] + assert generations == { + "generations": ["I'm fine thank you", "I'm fine thank you sir"], + "statistics": { + "input_tokens": 10, + "output_tokens": 10, + }, + } From 608e8b6b2132da68c781d1bd7a8eafd28ac0fd72 Mon Sep 17 00:00:00 2001 From: plaguss Date: Wed, 16 Oct 2024 15:56:32 +0200 Subject: [PATCH 12/35] Refactor to remove code duplication --- src/distilabel/llms/anthropic.py | 27 ++++++----- src/distilabel/llms/cohere.py | 48 +++++++++---------- src/distilabel/llms/groq.py | 35 ++++++-------- .../llms/huggingface/inference_endpoints.py | 31 ++++++------ .../llms/huggingface/transformers.py | 25 +++++----- src/distilabel/llms/litellm.py | 40 +++++++--------- src/distilabel/llms/llamacpp.py | 15 +++--- src/distilabel/llms/mistral.py | 28 ++++++----- src/distilabel/llms/ollama.py | 18 +++---- src/distilabel/llms/openai.py | 4 +- src/distilabel/llms/utils.py | 2 +- src/distilabel/llms/vertexai.py | 17 ++++--- src/distilabel/llms/vllm.py | 43 ++++++++--------- .../huggingface/test_inference_endpoints.py | 20 ++++---- tests/unit/llms/test_anthropic.py | 8 ++-- tests/unit/llms/test_cohere.py | 8 ++-- tests/unit/llms/test_groq.py | 8 ++-- tests/unit/llms/test_litellm.py | 4 +- tests/unit/llms/test_llamacpp.py | 1 - tests/unit/llms/test_mistral.py | 8 ++-- tests/unit/llms/test_ollama.py | 4 +- tests/unit/llms/test_openai.py | 18 +++---- tests/unit/llms/test_vertexai.py | 4 +- 23 files changed, 196 insertions(+), 220 deletions(-) diff --git a/src/distilabel/llms/anthropic.py b/src/distilabel/llms/anthropic.py index 5e19a75105..d050927d8b 100644 --- a/src/distilabel/llms/anthropic.py +++ b/src/distilabel/llms/anthropic.py @@ -29,6 +29,7 @@ from distilabel.llms.base import AsyncLLM from distilabel.llms.typing import GenerateOutput +from distilabel.llms.utils import prepare_output from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.steps.tasks.typing import ( FormattedInput, @@ -41,6 +42,8 @@ from anthropic import AsyncAnthropic from anthropic.types import Message + from distilabel.llms.typing import LLMStatistics + _ANTHROPIC_API_KEY_ENV_VAR_NAME = "ANTHROPIC_API_KEY" @@ -267,24 +270,22 @@ async def agenerate( # type: ignore **kwargs ) # type: ignore if structured_output: - raw_response = completion._raw_response - return { - "generations": [completion.model_dump_json()], - "statistics": { - "input_tokens": raw_response.usage.input_tokens, - "output_tokens": raw_response.usage.output_tokens, - }, - } + # raw_response = completion._raw_response + return prepare_output( + [completion.model_dump_json()], + **self._get_llm_statistics(completion._raw_response), + ) if (content := completion.content[0].text) is None: self._logger.warning( f"Received no response using Anthropic client (model: '{self.model}')." f" Finish reason was: {completion.stop_reason}" ) + return prepare_output([content], **self._get_llm_statistics(completion)) + + @staticmethod + def _get_llm_statistics(completion: "Message") -> "LLMStatistics": return { - "generations": [content], - "statistics": { - "input_tokens": completion.usage.input_tokens, - "output_tokens": completion.usage.output_tokens, - }, + "input_tokens": [completion.usage.input_tokens], + "output_tokens": [completion.usage.output_tokens], } diff --git a/src/distilabel/llms/cohere.py b/src/distilabel/llms/cohere.py index c630a1add0..8dba8a8c90 100644 --- a/src/distilabel/llms/cohere.py +++ b/src/distilabel/llms/cohere.py @@ -29,7 +29,7 @@ from distilabel.llms.base import AsyncLLM from distilabel.llms.typing import GenerateOutput -from distilabel.llms.utils import compute_tokens +from distilabel.llms.utils import compute_tokens, prepare_output from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.steps.tasks.typing import ( FormattedInput, @@ -40,6 +40,8 @@ from cohere import AsyncClient, ChatMessage, Message from pydantic import BaseModel + from distilabel.llms.typing import LLMStatistics + _COHERE_API_KEY_ENV_VAR_NAME = "COHERE_API_KEY" @@ -290,36 +292,32 @@ async def agenerate( # type: ignore response: Union["Message", "BaseModel"] = await self._aclient.chat(**kwargs) # type: ignore if structured_output: - # TODO: Refactor the dict response, it's quite similar in many LLMs - str_response = response.model_dump_json() - return { - "generations": [str_response], - "statistics": { - "input_tokens": compute_tokens(input, self._tokenizer.encode), - "output_tokens": compute_tokens( - orjson.dumps(str_response).decode("utf-8"), - self._tokenizer.encode, - ), - }, - } + return prepare_output( + [response.model_dump_json()], + **self._get_llm_statistics( + input, orjson.dumps(response.model_dump_json()).decode("utf-8") + ), # type: ignore + ) if (text := response.text) == "": self._logger.warning( # type: ignore f"Received no response using Cohere client (model: '{self.model}')." f" Finish reason was: {response.finish_reason}" ) - return { - "generations": [None], - "statistics": { - "input_tokens": compute_tokens(input, self._tokenizer.encode), - "output_tokens": 0, - }, - } + return prepare_output( + [None], + **self._get_llm_statistics(input, ""), + ) + + return prepare_output( + [text], + **self._get_llm_statistics(input, text), + ) + def _get_llm_statistics( + self, input: FormattedInput, output: str + ) -> "LLMStatistics": return { - "generations": [text], - "statistics": { - "input_tokens": compute_tokens(input, self._tokenizer.encode), - "output_tokens": compute_tokens(text, self._tokenizer.encode), - }, + "input_tokens": [compute_tokens(input, self._tokenizer.encode)], + "output_tokens": [compute_tokens(output, self._tokenizer.encode)], } diff --git a/src/distilabel/llms/groq.py b/src/distilabel/llms/groq.py index 2143a0ce0f..a66d735437 100644 --- a/src/distilabel/llms/groq.py +++ b/src/distilabel/llms/groq.py @@ -19,6 +19,7 @@ from distilabel.llms.base import AsyncLLM from distilabel.llms.typing import GenerateOutput +from distilabel.llms.utils import prepare_output from distilabel.steps.base import RuntimeParameter from distilabel.steps.tasks.typing import ( FormattedInput, @@ -27,6 +28,9 @@ if TYPE_CHECKING: from groq import AsyncGroq + from groq.types.chat.chat_completion import ChatCompletion + + from distilabel.llms.typing import LLMStatistics _GROQ_API_BASE_URL_ENV_VAR_NAME = "GROQ_BASE_URL" @@ -227,18 +231,11 @@ async def agenerate( # type: ignore completion = await self._aclient.chat.completions.create(**kwargs) # type: ignore if structured_output: - raw_response = completion._raw_response - return { - "generations": [completion.model_dump_json()], - "statistics": { - "input_tokens": raw_response.usage.prompt_tokens - if raw_response.usage - else 0, - "output_tokens": raw_response.usage.completion_tokens - if raw_response.usage - else 0, - }, - } + return prepare_output( + [completion.model_dump_json()], + **self._get_llm_statistics(completion._raw_response), + ) + generations = [] for choice in completion.choices: if (content := choice.message.content) is None: @@ -247,15 +244,11 @@ async def agenerate( # type: ignore f" Finish reason was: {choice.finish_reason}" ) generations.append(content) + return prepare_output(generations, **self._get_llm_statistics(completion)) + @staticmethod + def _get_llm_statistics(completion: "ChatCompletion") -> "LLMStatistics": return { - "generations": generations, - "statistics": { - "input_tokens": completion.usage.prompt_tokens - if completion.usage - else 0, - "output_tokens": completion.usage.completion_tokens - if completion.usage - else 0, - }, + "input_tokens": [completion.usage.prompt_tokens if completion else 0], + "output_tokens": [completion.usage.completion_tokens if completion else 0], } diff --git a/src/distilabel/llms/huggingface/inference_endpoints.py b/src/distilabel/llms/huggingface/inference_endpoints.py index 3eef33ce54..a07a6efd7c 100644 --- a/src/distilabel/llms/huggingface/inference_endpoints.py +++ b/src/distilabel/llms/huggingface/inference_endpoints.py @@ -32,6 +32,7 @@ from distilabel.llms.base import AsyncLLM from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin from distilabel.llms.typing import GenerateOutput +from distilabel.llms.utils import prepare_output from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.steps.tasks.typing import ( FormattedInput, @@ -423,16 +424,14 @@ async def _generate_with_text_generation( f" Finish reason was: {e}" ) # NOTE: I cannot see the input tokens returned, and given that the model can be private, I cannot - # count the input tokens - return { - "generations": [completion.generated_text], - "statistics": { - "input_tokens": 0, - "output_tokens": completion.details.generated_tokens - if completion.details - else 0, - }, - } + # count them... + return prepare_output( + [completion.generated_text], + input_tokens=[0], + output_tokens=[ + completion.details.generated_tokens if completion.details else 0 + ], + ) async def _generate_with_chat_completion( self, @@ -478,13 +477,11 @@ async def _generate_with_chat_completion( f"⚠️ Received no response using Inference Client (model: '{self.model_name}')." f" Finish reason was: {e}" ) - return { - "generations": [message], - "statistics": { - "input_tokens": completion.usage.prompt_tokens, - "output_tokens": completion.usage.completion_tokens, - }, - } + return prepare_output( + [message], + input_tokens=[completion.usage.prompt_tokens], + output_tokens=[completion.usage.completion_tokens], + ) def _check_stop_sequences( self, diff --git a/src/distilabel/llms/huggingface/transformers.py b/src/distilabel/llms/huggingface/transformers.py index 90348fdc18..b5db5104a6 100644 --- a/src/distilabel/llms/huggingface/transformers.py +++ b/src/distilabel/llms/huggingface/transformers.py @@ -21,7 +21,7 @@ from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin from distilabel.llms.typing import GenerateOutput -from distilabel.llms.utils import compute_tokens +from distilabel.llms.utils import compute_tokens, prepare_output from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.steps.tasks.typing import OutlinesStructuredOutputType, StandardInput from distilabel.utils.huggingface import HF_TOKEN_ENV_VAR @@ -242,19 +242,18 @@ def generate( # type: ignore result = [] for input, output in zip(inputs, llm_output): result.append( - { - "generations": output, - "statistics": { - "input_tokens": [ - compute_tokens(input, self._pipeline.tokenizer.encode) - ], - "output_tokens": [ - compute_tokens(row, self._pipeline.tokenizer.encode) - for row in output - ], - }, - } + prepare_output( + output, + input_tokens=[ + compute_tokens(input, self._pipeline.tokenizer.encode) + ], + output_tokens=[ + compute_tokens(row, self._pipeline.tokenizer.encode) + for row in output + ], + ) ) + return result def get_last_hidden_states( diff --git a/src/distilabel/llms/litellm.py b/src/distilabel/llms/litellm.py index 5c4646c6ef..384837a8a9 100644 --- a/src/distilabel/llms/litellm.py +++ b/src/distilabel/llms/litellm.py @@ -20,6 +20,7 @@ from distilabel.llms.base import AsyncLLM from distilabel.llms.typing import GenerateOutput +from distilabel.llms.utils import prepare_output from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.steps.tasks.typing import FormattedInput, InstructorStructuredOutputType @@ -258,24 +259,23 @@ async def _call_aclient_until_n_choices() -> List["Choices"]: raise e generations = [] - input_tokens = token_counter(model=self.model, messages=input) - output_tokens = 0 + input_tokens = [token_counter(model=self.model, messages=input)] + output_tokens = [] if self.structured_output: for choice in choices: generations.append(choice.model_dump_json()) - output_tokens += token_counter( - model=self.model, - text=orjson.dumps(choice.model_dump_json()).decode("utf-8"), + output_tokens.append( + token_counter( + model=self.model, + text=orjson.dumps(choice.model_dump_json()).decode("utf-8"), + ) ) - - return { - "generations": generations, - "statistics": { - "input_tokens": input_tokens, - "output_tokens": output_tokens, - }, - } + return prepare_output( + generations, + input_tokens=input_tokens, + output_tokens=output_tokens, + ) for choice in choices: if (content := choice.message.content) is None: @@ -284,12 +284,8 @@ async def _call_aclient_until_n_choices() -> List["Choices"]: f" Finish reason was: {choice.finish_reason}" ) generations.append(content) - output_tokens += token_counter(model=self.model, text=content) - - return { - "generations": generations, - "statistics": { - "input_tokens": input_tokens, - "output_tokens": output_tokens, - }, - } + output_tokens.append(token_counter(model=self.model, text=content)) + + return prepare_output( + generations, input_tokens=input_tokens, output_tokens=output_tokens + ) diff --git a/src/distilabel/llms/llamacpp.py b/src/distilabel/llms/llamacpp.py index 1f1176390d..b7e83dc126 100644 --- a/src/distilabel/llms/llamacpp.py +++ b/src/distilabel/llms/llamacpp.py @@ -18,6 +18,7 @@ from distilabel.llms.base import LLM from distilabel.llms.typing import GenerateOutput +from distilabel.llms.utils import prepare_output from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.steps.tasks.typing import FormattedInput, OutlinesStructuredOutputType @@ -244,15 +245,11 @@ def generate( # type: ignore outputs.append(chat_completions["choices"][0]["message"]["content"]) output_tokens.append(chat_completions["usage"]["completion_tokens"]) batch_outputs.append( - { - "generations": outputs, - "statistics": { - "input_tokens": [ - chat_completions["usage"]["prompt_tokens"] - ], # Should be the same for the n_generations - "output_tokens": output_tokens, - }, - } + prepare_output( + outputs, + input_tokens=[chat_completions["usage"]["prompt_tokens"]], + output_tokens=output_tokens, + ) ) return batch_outputs diff --git a/src/distilabel/llms/mistral.py b/src/distilabel/llms/mistral.py index 92a88ce71b..c3147014fc 100644 --- a/src/distilabel/llms/mistral.py +++ b/src/distilabel/llms/mistral.py @@ -19,6 +19,7 @@ from distilabel.llms.base import AsyncLLM from distilabel.llms.typing import GenerateOutput +from distilabel.llms.utils import prepare_output from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.steps.tasks.typing import ( FormattedInput, @@ -27,6 +28,9 @@ if TYPE_CHECKING: from mistralai import Mistral + from mistralai.models.chatcompletionresponse import ChatCompletionResponse + + from distilabel.llms.typing import LLMStatistics _MISTRALAI_API_KEY_ENV_VAR_NAME = "MISTRAL_API_KEY" @@ -221,14 +225,10 @@ async def agenerate( # type: ignore completion = await self._aclient.chat.complete_async(**kwargs) # type: ignore if structured_output: - raw_response = completion._raw_response - return { - "generations": [completion.model_dump_json()], - "statistics": { - "input_tokens": raw_response.usage.prompt_tokens, - "output_tokens": raw_response.usage.completion_tokens, - }, - } + return prepare_output( + [completion.model_dump_json()], + **self._get_llm_statistics(completion._raw_response), + ) for choice in completion.choices: if (content := choice.message.content) is None: @@ -237,10 +237,12 @@ async def agenerate( # type: ignore f" Finish reason was: {choice.finish_reason}" ) generations.append(content) + + return prepare_output(generations, **self._get_llm_statistics(completion)) + + @staticmethod + def _get_llm_statistics(completion: "ChatCompletionResponse") -> "LLMStatistics": return { - "generations": generations, - "statistics": { - "input_tokens": completion.usage.prompt_tokens, - "output_tokens": completion.usage.completion_tokens, - }, + "input_tokens": [completion.usage.prompt_tokens], + "output_tokens": [completion.usage.completion_tokens], } diff --git a/src/distilabel/llms/ollama.py b/src/distilabel/llms/ollama.py index 37a808ed4f..107d3bf7a3 100644 --- a/src/distilabel/llms/ollama.py +++ b/src/distilabel/llms/ollama.py @@ -19,12 +19,15 @@ from distilabel.llms.base import AsyncLLM from distilabel.llms.typing import GenerateOutput +from distilabel.llms.utils import prepare_output from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.steps.tasks.typing import InstructorStructuredOutputType, StandardInput if TYPE_CHECKING: from ollama import AsyncClient + from distilabel.llms.typing import LLMStatistics + # Copied from `ollama._types.Options` class Options(TypedDict, total=False): @@ -159,8 +162,6 @@ async def agenerate( # type: ignore A list of strings as completion for the given input. """ text = None - input_tokens = 0 - output_tokens = 0 try: completion: Dict[str, Any] = await self._aclient.chat( # type: ignore model=self.model, @@ -171,18 +172,17 @@ async def agenerate( # type: ignore keep_alive=keep_alive, ) text = completion["message"]["content"] - input_tokens = completion["prompt_eval_count"] - output_tokens = completion["eval_count"] except Exception as e: self._logger.warning( # type: ignore f"⚠️ Received no response using Ollama client (model: '{self.model_name}')." f" Finish reason was: {e}" ) + return prepare_output([text], **self._get_llm_statistics(completion)) + + @staticmethod + def _get_llm_statistics(completion: Dict[str, Any]) -> "LLMStatistics": return { - "generations": [text], - "statistics": { - "input_tokens": input_tokens, - "output_tokens": output_tokens, - }, + "input_tokens": [completion["prompt_eval_count"]], + "output_tokens": [completion["eval_count"]], } diff --git a/src/distilabel/llms/openai.py b/src/distilabel/llms/openai.py index dc60b81ceb..ab6f54ab3b 100644 --- a/src/distilabel/llms/openai.py +++ b/src/distilabel/llms/openai.py @@ -688,6 +688,6 @@ def _name_for_openai_files(self, file_no: int) -> str: @staticmethod def _get_llm_statistics(completion: "OpenAIChatCompletion") -> "LLMStatistics": return { - "input_tokens": completion.usage.prompt_tokens if completion else 0, - "output_tokens": completion.usage.completion_tokens if completion else 0, + "input_tokens": [completion.usage.prompt_tokens if completion else 0], + "output_tokens": [completion.usage.completion_tokens if completion else 0], } diff --git a/src/distilabel/llms/utils.py b/src/distilabel/llms/utils.py index b2cf307ee8..2661e7e207 100644 --- a/src/distilabel/llms/utils.py +++ b/src/distilabel/llms/utils.py @@ -60,6 +60,6 @@ def prepare_output( "generations": generations, "statistics": { "input_tokens": input_tokens or 0, - "output_tokens": input_tokens or 0, + "output_tokens": output_tokens or 0, }, } diff --git a/src/distilabel/llms/vertexai.py b/src/distilabel/llms/vertexai.py index 7e85fbb079..3a9a89ea94 100644 --- a/src/distilabel/llms/vertexai.py +++ b/src/distilabel/llms/vertexai.py @@ -18,11 +18,14 @@ from distilabel.llms.base import AsyncLLM from distilabel.llms.typing import GenerateOutput +from distilabel.llms.utils import prepare_output from distilabel.steps.tasks.typing import StandardInput if TYPE_CHECKING: from vertexai.generative_models import Content, GenerationResponse, GenerativeModel + from distilabel.llms.typing import LLMStatistics + class VertexAILLM(AsyncLLM): """VertexAI LLM implementation running the async API clients for Gemini. @@ -160,24 +163,20 @@ async def agenerate( # type: ignore ) text = None - input_tokens = 0 - output_tokens = 0 try: text = content.candidates[0].text - input_tokens = content.usage_metadata.prompt_token_count - output_tokens = content.usage_metadata.candidates_token_count except ValueError: self._logger.warning( # type: ignore f"Received no response using VertexAI client (model: '{self.model}')." f" Finish reason was: '{content.candidates[0].finish_reason}'." ) + return prepare_output([text], **self._get_llm_statistics(content)) + @staticmethod + def _get_llm_statistics(content: "GenerationResponse") -> "LLMStatistics": return { - "generations": [text], - "statistics": { - "input_tokens": input_tokens, - "output_tokens": output_tokens, - }, + "input_tokens": [content.usage_metadata.prompt_token_count], + "output_tokens": [content.usage_metadata.candidates_token_count], } diff --git a/src/distilabel/llms/vllm.py b/src/distilabel/llms/vllm.py index c6aab8ede9..3d2762fb66 100644 --- a/src/distilabel/llms/vllm.py +++ b/src/distilabel/llms/vllm.py @@ -34,7 +34,7 @@ from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin from distilabel.llms.openai import OpenAILLM from distilabel.llms.typing import GenerateOutput -from distilabel.llms.utils import compute_tokens +from distilabel.llms.utils import compute_tokens, prepare_output from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.steps.tasks.typing import FormattedInput, OutlinesStructuredOutputType @@ -42,9 +42,11 @@ from openai import OpenAI # noqa from transformers import PreTrainedTokenizer from vllm import LLM as _vLLM - from vllm.outputs import RequestOutputs + from vllm.outputs import RequestOutputs, CompletionOutput from distilabel.steps.tasks.typing import StandardInput + from distilabel.llms.typing import LLMStatistics + LogitsProcessorFn = Union[ Callable[[List[int], Any], Any], @@ -397,19 +399,13 @@ def generate( # type: ignore ] for input, outputs in zip(prepared_inputs, batch_outputs): generations.append( - { - "generations": [output.text for output in outputs.outputs], - "statistics": { - "input_tokens": [ - compute_tokens(input, self._tokenizer.encode) - ], - "output_tokens": [ - len(output.token_ids) for output in outputs.outputs - ], - }, - } + prepare_output( + [output.text for output in outputs.outputs], + **self._get_llm_statistics(input, outputs), + ) ) + # TODO: This must be updated for with the statistics # If logits_processor is set, we need to sort the outputs back to the original order # (would be needed only if we have multiple structured outputs in the dataset) if sorted_indices is not None: @@ -439,6 +435,14 @@ def _prepare_structured_output( self.structured_output["schema"] = schema return result["processor"] + def _get_llm_statistics( + self, input: "FormattedInput", outputs: "CompletionOutput" + ) -> "LLMStatistics": + return { + "input_tokens": [compute_tokens(input, self._tokenizer.encode)], + "output_tokens": [len(output.token_ids) for output in outputs.outputs], + } + class ClientvLLM(OpenAILLM, MagpieChatTemplateMixin): """A client for the `vLLM` server implementing the OpenAI API specification. @@ -622,17 +626,8 @@ async def agenerate( # type: ignore f" Finish reason was: {choice.finish_reason}" ) generations.append(text) - return { - "generations": generations, - "statistics": { - "input_tokens": completion.usage.prompt_tokens - if completion.usage - else 0, - "output_tokens": completion.usage.completion_tokens - if completion.usage - else 0, - }, - } + + return prepare_output(generations, **self._get_llm_statistics(completion)) def _sort_batches( diff --git a/tests/unit/llms/huggingface/test_inference_endpoints.py b/tests/unit/llms/huggingface/test_inference_endpoints.py index 8887789d4b..7076b50af0 100644 --- a/tests/unit/llms/huggingface/test_inference_endpoints.py +++ b/tests/unit/llms/huggingface/test_inference_endpoints.py @@ -149,8 +149,8 @@ async def test_agenerate_with_text_generation( assert result == { "generations": ["Aenean hendrerit aliquam velit..."], "statistics": { - "input_tokens": 0, - "output_tokens": 66, + "input_tokens": [0], + "output_tokens": [66], }, } @@ -196,8 +196,8 @@ async def test_agenerate_with_chat_completion( assert result == { "generations": [" Aenean hendrerit aliquam velit. ..."], "statistics": { - "input_tokens": 18, - "output_tokens": 66, + "input_tokens": [18], + "output_tokens": [66], }, } @@ -243,8 +243,8 @@ async def test_agenerate_with_chat_completion_fails( assert result == { "generations": [None], "statistics": { - "input_tokens": 18, - "output_tokens": 66, + "input_tokens": [18], + "output_tokens": [66], }, } @@ -294,8 +294,8 @@ async def test_generate(self, mock_inference_client: MagicMock) -> None: { "generations": [None], "statistics": { - "input_tokens": 18, - "output_tokens": 66, + "input_tokens": [18], + "output_tokens": [66], }, } ] @@ -334,8 +334,8 @@ async def test_agenerate_with_structured_output( assert result == { "generations": ["Aenean hendrerit aliquam velit..."], "statistics": { - "input_tokens": 0, - "output_tokens": 66, + "input_tokens": [0], + "output_tokens": [66], }, } diff --git a/tests/unit/llms/test_anthropic.py b/tests/unit/llms/test_anthropic.py index 9d165280f8..33e63527a9 100644 --- a/tests/unit/llms/test_anthropic.py +++ b/tests/unit/llms/test_anthropic.py @@ -55,7 +55,7 @@ async def test_agenerate(self, mock_anthropic: MagicMock) -> None: ) assert result == { "generations": ["Aenean hendrerit aliquam velit..."], - "statistics": {"input_tokens": 100, "output_tokens": 100}, + "statistics": {"input_tokens": [100], "output_tokens": [100]}, } @pytest.mark.asyncio @@ -91,8 +91,8 @@ async def test_agenerate_structured(self, mock_openai: MagicMock) -> None: assert generation == { "generations": [sample_user.model_dump_json()], "statistics": { - "input_tokens": 100, - "output_tokens": 100, + "input_tokens": [100], + "output_tokens": [100], }, } @@ -127,7 +127,7 @@ async def test_generate(self, mock_anthropic: MagicMock) -> None: assert result == [ { "generations": ["Aenean hendrerit aliquam velit..."], - "statistics": {"input_tokens": 100, "output_tokens": 100}, + "statistics": {"input_tokens": [100], "output_tokens": [100]}, } ] diff --git a/tests/unit/llms/test_cohere.py b/tests/unit/llms/test_cohere.py index b88da01804..4ce26fbde4 100644 --- a/tests/unit/llms/test_cohere.py +++ b/tests/unit/llms/test_cohere.py @@ -67,11 +67,11 @@ async def test_agenerate(self, mock_async_client: mock.MagicMock) -> None: ) assert result == { "generations": ["Aenean hendrerit aliquam velit..."], - "statistics": {"input_tokens": 23, "output_tokens": 16}, + "statistics": {"input_tokens": [23], "output_tokens": [16]}, } @pytest.mark.skipif( - sys.version_info < (3, 9), reason="`mistralai` requires Python 3.9 or higher" + sys.version_info < (3, 9), reason="`cohere` requires Python 3.9 or higher" ) @pytest.mark.asyncio async def test_agenerate_structured( @@ -103,7 +103,7 @@ async def test_agenerate_structured( ) assert generation == { "generations": [sample_user.model_dump_json()], - "statistics": {"input_tokens": 23, "output_tokens": 26}, + "statistics": {"input_tokens": [23], "output_tokens": [26]}, } @pytest.mark.asyncio @@ -131,7 +131,7 @@ async def test_generate(self, mock_async_client: mock.MagicMock) -> None: assert result == [ { "generations": ["Aenean hendrerit aliquam velit..."], - "statistics": {"input_tokens": 23, "output_tokens": 16}, + "statistics": {"input_tokens": [23], "output_tokens": [16]}, } ] diff --git a/tests/unit/llms/test_groq.py b/tests/unit/llms/test_groq.py index 534214a5ef..1cbc04cf05 100644 --- a/tests/unit/llms/test_groq.py +++ b/tests/unit/llms/test_groq.py @@ -55,11 +55,11 @@ async def test_agenerate(self, mock_groq: MagicMock) -> None: ] ) == { "generations": [" Aenean hendrerit aliquam velit. ..."], - "statistics": {"input_tokens": 100, "output_tokens": 100}, + "statistics": {"input_tokens": [100], "output_tokens": [100]}, } @pytest.mark.skipif( - sys.version_info < (3, 9), reason="`mistralai` requires Python 3.9 or higher" + sys.version_info < (3, 9), reason="`groq` requires Python 3.9 or higher" ) @pytest.mark.asyncio async def test_agenerate_structured(self, mock_openai: MagicMock) -> None: @@ -93,7 +93,7 @@ async def test_agenerate_structured(self, mock_openai: MagicMock) -> None: ) assert generation == { "generations": [sample_user.model_dump_json()], - "statistics": {"input_tokens": 100, "output_tokens": 100}, + "statistics": {"input_tokens": [100], "output_tokens": [100]}, } @pytest.mark.asyncio @@ -122,7 +122,7 @@ async def test_generate(self, mock_groq: MagicMock) -> None: ) == [ { "generations": ["Aenean hendrerit aliquam velit..."], - "statistics": {"input_tokens": 100, "output_tokens": 100}, + "statistics": {"input_tokens": [100], "output_tokens": [100]}, } ] diff --git a/tests/unit/llms/test_litellm.py b/tests/unit/llms/test_litellm.py index d0a53f66ae..c976ba789a 100644 --- a/tests/unit/llms/test_litellm.py +++ b/tests/unit/llms/test_litellm.py @@ -53,7 +53,7 @@ async def test_agenerate(self, mock_litellm: MagicMock, model: str) -> None: ) assert result == { "generations": [" Aenean hendrerit aliquam velit. ..."], - "statistics": {"input_tokens": 21, "output_tokens": 11}, + "statistics": {"input_tokens": [21], "output_tokens": [11]}, } @pytest.mark.asyncio @@ -82,7 +82,7 @@ async def test_generate(self, mock_litellm: MagicMock, model: str) -> None: assert result == [ { "generations": [" Aenean hendrerit aliquam velit. ..."], - "statistics": {"input_tokens": 21, "output_tokens": 11}, + "statistics": {"input_tokens": [21], "output_tokens": [11]}, } ] diff --git a/tests/unit/llms/test_llamacpp.py b/tests/unit/llms/test_llamacpp.py index e3dec0f04c..65b8147da4 100644 --- a/tests/unit/llms/test_llamacpp.py +++ b/tests/unit/llms/test_llamacpp.py @@ -54,7 +54,6 @@ def test_generate(self, llm: LlamaCppLLM) -> None: ], num_generations=3, ) - print("RESPONSE", responses) assert len(responses) == 2 generations = responses[0]["generations"] statistics = responses[0]["statistics"] diff --git a/tests/unit/llms/test_mistral.py b/tests/unit/llms/test_mistral.py index a613f9b49c..1a9b170fa7 100644 --- a/tests/unit/llms/test_mistral.py +++ b/tests/unit/llms/test_mistral.py @@ -62,7 +62,7 @@ async def test_agenerate(self, mock_mistral: MagicMock) -> None: ) assert result == { "generations": [" Aenean hendrerit aliquam velit. ..."], - "statistics": {"input_tokens": 10, "output_tokens": 10}, + "statistics": {"input_tokens": [10], "output_tokens": [10]}, } @pytest.mark.asyncio @@ -101,8 +101,8 @@ async def test_agenerate_structured(self, mock_mistral: MagicMock) -> None: assert generation == { "generations": [sample_user.model_dump_json()], "statistics": { - "input_tokens": 100, - "output_tokens": 100, + "input_tokens": [100], + "output_tokens": [100], }, } @@ -137,7 +137,7 @@ async def test_generate(self, mock_mistral: MagicMock) -> None: assert result == [ { "generations": [" Aenean hendrerit aliquam velit. ..."], - "statistics": {"input_tokens": 10, "output_tokens": 10}, + "statistics": {"input_tokens": [10], "output_tokens": [10]}, } ] diff --git a/tests/unit/llms/test_ollama.py b/tests/unit/llms/test_ollama.py index 34f32b82a4..fcb0c0a612 100644 --- a/tests/unit/llms/test_ollama.py +++ b/tests/unit/llms/test_ollama.py @@ -50,7 +50,7 @@ async def test_agenerate(self, mock_ollama: MagicMock) -> None: ) assert result == { "generations": ["Aenean hendrerit aliquam velit..."], - "statistics": {"input_tokens": 10, "output_tokens": 10}, + "statistics": {"input_tokens": [10], "output_tokens": [10]}, } @pytest.mark.asyncio @@ -82,7 +82,7 @@ async def test_generate(self, mock_ollama: MagicMock) -> None: assert result == [ { "generations": ["Aenean hendrerit aliquam velit..."], - "statistics": {"input_tokens": 10, "output_tokens": 10}, + "statistics": {"input_tokens": [10], "output_tokens": [10]}, } ] diff --git a/tests/unit/llms/test_openai.py b/tests/unit/llms/test_openai.py index 463d850f77..9bed215030 100644 --- a/tests/unit/llms/test_openai.py +++ b/tests/unit/llms/test_openai.py @@ -83,7 +83,7 @@ async def test_agenerate( ) assert result == { "generations": [" Aenean hendrerit aliquam velit. ..."], - "statistics": {"input_tokens": 100, "output_tokens": 100}, + "statistics": {"input_tokens": [100], "output_tokens": [100]}, } @pytest.mark.asyncio @@ -124,7 +124,7 @@ async def test_agenerate_structured( ) assert generation == { "generations": [sample_user.model_dump_json()], - "statistics": {"input_tokens": 100, "output_tokens": 100}, + "statistics": {"input_tokens": [100], "output_tokens": [100]}, } @pytest.mark.skipif( @@ -161,7 +161,7 @@ async def test_generate( assert result == [ { "generations": [" Aenean hendrerit aliquam velit. ..."], - "statistics": {"input_tokens": 100, "output_tokens": 100}, + "statistics": {"input_tokens": [100], "output_tokens": [100]}, } ] @@ -278,15 +278,15 @@ def test_check_and_get_batch_results( { "generations": ["output 1"], "statistics": { - "input_tokens": 100, - "output_tokens": 100, + "input_tokens": [100], + "output_tokens": [100], }, }, { "generations": ["output 2"], "statistics": { - "input_tokens": 100, - "output_tokens": 100, + "input_tokens": [100], + "output_tokens": [100], }, }, ] @@ -388,8 +388,8 @@ def test_parse_output( assert result == { "generations": [" Aenean hendrerit aliquam velit. ..."], "statistics": { - "input_tokens": 100, - "output_tokens": 100, + "input_tokens": [100], + "output_tokens": [100], }, } diff --git a/tests/unit/llms/test_vertexai.py b/tests/unit/llms/test_vertexai.py index c2bb14c595..35059e7eb6 100644 --- a/tests/unit/llms/test_vertexai.py +++ b/tests/unit/llms/test_vertexai.py @@ -70,7 +70,7 @@ async def test_agenerate(self, mock_generative_model: MagicMock) -> None: ) assert result == { "generations": [" Aenean hendrerit aliquam velit. ..."], - "statistics": {"input_tokens": 10, "output_tokens": 10}, + "statistics": {"input_tokens": [10], "output_tokens": [10]}, } @pytest.mark.asyncio @@ -118,7 +118,7 @@ async def test_generate(self, mock_generative_model: MagicMock) -> None: assert result == [ { "generations": [" Aenean hendrerit aliquam velit. ..."], - "statistics": {"input_tokens": 10, "output_tokens": 10}, + "statistics": {"input_tokens": [10], "output_tokens": [10]}, } ] From 8c35af571f42bd22031ef6ac93e358300a9d90bb Mon Sep 17 00:00:00 2001 From: plaguss Date: Thu, 17 Oct 2024 15:29:13 +0200 Subject: [PATCH 13/35] Fix async llms not returning properly the generations grouped by num_generations --- src/distilabel/llms/base.py | 36 +++++++++++++++++++++++++++--------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/src/distilabel/llms/base.py b/src/distilabel/llms/base.py index 3001a477e0..cabfd37068 100644 --- a/src/distilabel/llms/base.py +++ b/src/distilabel/llms/base.py @@ -33,7 +33,6 @@ RuntimeParametersMixin, ) from distilabel.utils.docstring import parse_google_docstring -from distilabel.utils.itertools import grouper from distilabel.utils.notebook import in_notebook from distilabel.utils.serialization import _Serializable @@ -460,21 +459,15 @@ async def _agenerate( for input in inputs ] result = await asyncio.gather(*tasks) - # TODO: Update the object returned to be the same as in synchronous LLMs with batches. + return merge_responses(result) - return result - - # TODO: Update the object returned to be the same as in synchronous LLMs with batches. tasks = [ asyncio.create_task(self.agenerate(input=input, **kwargs)) for input in inputs for _ in range(num_generations) ] outputs = await asyncio.gather(*tasks) - return [ - list(group)[0] - for group in grouper(outputs, n=num_generations, incomplete="ignore") - ] + return merge_responses(outputs) def generate( self, @@ -594,3 +587,28 @@ def _prepare_kwargs( }, ) return arguments + + +def merge_responses(responses: List[Dict[str, Any]]) -> List["GenerateOutput"]: + """Helper function to group the responses from `LLM.agenerate` method according + to the number of generations requested. + + Args: + responses: the responses from the `LLM.agenerate` method. + + Returns: + Merges the texts and statistics of the responses into a single response. + """ + if not responses: + return [] + + first = responses[0] + return [ + { + "generations": sum((r["generations"] for r in responses), []), + "statistics": { + key: sum((r["statistics"][key] for r in responses), []) + for key in first["statistics"] + }, + } + ] From 6d19de7c47bf530ce54bddcb166751e5315d41f4 Mon Sep 17 00:00:00 2001 From: plaguss Date: Thu, 17 Oct 2024 15:53:08 +0200 Subject: [PATCH 14/35] Fix async llms not processing multiple generations --- src/distilabel/llms/litellm.py | 4 +- src/distilabel/llms/llamacpp.py | 3 +- src/distilabel/llms/vllm.py | 8 +-- .../huggingface/test_inference_endpoints.py | 54 +++++++++++++------ tests/unit/llms/test_openai.py | 43 +++++++++++---- 5 files changed, 81 insertions(+), 31 deletions(-) diff --git a/src/distilabel/llms/litellm.py b/src/distilabel/llms/litellm.py index 384837a8a9..aa122078e1 100644 --- a/src/distilabel/llms/litellm.py +++ b/src/distilabel/llms/litellm.py @@ -259,7 +259,9 @@ async def _call_aclient_until_n_choices() -> List["Choices"]: raise e generations = [] - input_tokens = [token_counter(model=self.model, messages=input)] + input_tokens = [ + token_counter(model=self.model, messages=input) + ] * num_generations output_tokens = [] if self.structured_output: diff --git a/src/distilabel/llms/llamacpp.py b/src/distilabel/llms/llamacpp.py index b7e83dc126..fcfc9c0f37 100644 --- a/src/distilabel/llms/llamacpp.py +++ b/src/distilabel/llms/llamacpp.py @@ -247,7 +247,8 @@ def generate( # type: ignore batch_outputs.append( prepare_output( outputs, - input_tokens=[chat_completions["usage"]["prompt_tokens"]], + input_tokens=[chat_completions["usage"]["prompt_tokens"]] + * num_generations, output_tokens=output_tokens, ) ) diff --git a/src/distilabel/llms/vllm.py b/src/distilabel/llms/vllm.py index 3d2762fb66..425164fb63 100644 --- a/src/distilabel/llms/vllm.py +++ b/src/distilabel/llms/vllm.py @@ -405,7 +405,7 @@ def generate( # type: ignore ) ) - # TODO: This must be updated for with the statistics + # TODO: This must be updated with the statistics # If logits_processor is set, we need to sort the outputs back to the original order # (would be needed only if we have multiple structured outputs in the dataset) if sorted_indices is not None: @@ -438,9 +438,11 @@ def _prepare_structured_output( def _get_llm_statistics( self, input: "FormattedInput", outputs: "CompletionOutput" ) -> "LLMStatistics": + output_tokens = [len(output.token_ids) for output in outputs.outputs] return { - "input_tokens": [compute_tokens(input, self._tokenizer.encode)], - "output_tokens": [len(output.token_ids) for output in outputs.outputs], + "input_tokens": [compute_tokens(input, self._tokenizer.encode)] + * len(output_tokens), + "output_tokens": output_tokens, } diff --git a/tests/unit/llms/huggingface/test_inference_endpoints.py b/tests/unit/llms/huggingface/test_inference_endpoints.py index 7076b50af0..72e7599e0d 100644 --- a/tests/unit/llms/huggingface/test_inference_endpoints.py +++ b/tests/unit/llms/huggingface/test_inference_endpoints.py @@ -14,7 +14,7 @@ import os import random -from typing import Generator +from typing import Any, Dict, Generator, List from unittest import mock from unittest.mock import AsyncMock, MagicMock, patch @@ -248,11 +248,41 @@ async def test_agenerate_with_chat_completion_fails( }, } + @pytest.mark.parametrize( + "num_generations, expected_result", + [ + ( + 1, + [ + { + "generations": ["text"], + "statistics": {"input_tokens": [18], "output_tokens": [66]}, + } + ], + ), + ( + 2, + [ + { + "generations": ["text"] * 2, + "statistics": { + "input_tokens": [18, 18], + "output_tokens": [66, 66], + }, + } + ], + ), + ], + ) @pytest.mark.asyncio - async def test_generate(self, mock_inference_client: MagicMock) -> None: + async def test_generate( + self, + mock_inference_client: MagicMock, + num_generations: int, + expected_result: List[Dict[str, Any]], + ) -> None: llm = InferenceEndpointsLLM( model_id="distilabel-internal-testing/tiny-random-mistral", - # tokenizer_id="distilabel-internal-testing/tiny-random-mistral", ) llm.load() @@ -264,10 +294,11 @@ async def test_generate(self, mock_inference_client: MagicMock) -> None: index=0, message=ChatCompletionOutputMessage( role="assistant", - content=None, + content="text", ), ) - ], + ] + * num_generations, created=1721045246, id="", model="meta-llama/Meta-Llama-3-70B-Instruct", @@ -288,17 +319,10 @@ async def test_generate(self, mock_inference_client: MagicMock) -> None: "content": "Lorem ipsum dolor sit amet, consectetur adipiscing elit.", }, ] - ] + ], + num_generations=num_generations, ) - assert result == [ - { - "generations": [None], - "statistics": { - "input_tokens": [18], - "output_tokens": [66], - }, - } - ] + assert result == expected_result @pytest.mark.asyncio async def test_agenerate_with_structured_output( diff --git a/tests/unit/llms/test_openai.py b/tests/unit/llms/test_openai.py index 9bed215030..dfd8925f58 100644 --- a/tests/unit/llms/test_openai.py +++ b/tests/unit/llms/test_openai.py @@ -15,7 +15,7 @@ import os import sys from textwrap import dedent -from typing import Any, Dict +from typing import Any, Dict, List from unittest import mock from unittest.mock import AsyncMock, MagicMock, Mock, patch @@ -130,17 +130,43 @@ async def test_agenerate_structured( @pytest.mark.skipif( sys.version_info < (3, 9), reason="`mistralai` requires Python 3.9 or higher" ) + @pytest.mark.parametrize( + "num_generations, expected_result", + [ + ( + 1, + [ + { + "generations": [" Aenean hendrerit aliquam velit. ..."], + "statistics": {"input_tokens": [100], "output_tokens": [100]}, + } + ], + ), + ( + 2, + [ + { + "generations": [" Aenean hendrerit aliquam velit. ..."] * 2, + "statistics": {"input_tokens": [100], "output_tokens": [100]}, + } + ], + ), + ], + ) @pytest.mark.asyncio async def test_generate( - self, async_openai_mock: MagicMock, _openai_mock: MagicMock + self, + async_openai_mock: MagicMock, + _openai_mock: MagicMock, + num_generations: int, + expected_result: List[Dict[str, Any]], ) -> None: llm = OpenAILLM(model=self.model_id, api_key="api.key") # type: ignore llm._aclient = async_openai_mock mocked_completion = Mock( - choices=[ - Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ...")) - ], + choices=[Mock(message=Mock(content=" Aenean hendrerit aliquam velit. ..."))] + * num_generations, usage=Mock(prompt_tokens=100, completion_tokens=100), ) llm._aclient.chat.completions.create = AsyncMock(return_value=mocked_completion) @@ -158,12 +184,7 @@ async def test_generate( ] ] ) - assert result == [ - { - "generations": [" Aenean hendrerit aliquam velit. ..."], - "statistics": {"input_tokens": [100], "output_tokens": [100]}, - } - ] + assert result == expected_result with pytest.raises(ValueError): llm.generate( From 9746d75a363ed5e66d66bc60a55542b4ef27761c Mon Sep 17 00:00:00 2001 From: plaguss Date: Fri, 18 Oct 2024 12:26:54 +0200 Subject: [PATCH 15/35] Fix vllm sorting mechanism and add mocked generate method to the test suite --- src/distilabel/llms/vllm.py | 78 ++++++++++++++++--- tests/unit/llms/test_vllm.py | 140 +++++++++++++++++++---------------- 2 files changed, 144 insertions(+), 74 deletions(-) diff --git a/src/distilabel/llms/vllm.py b/src/distilabel/llms/vllm.py index 425164fb63..901e70bb85 100644 --- a/src/distilabel/llms/vllm.py +++ b/src/distilabel/llms/vllm.py @@ -235,7 +235,7 @@ def prepare_input(self, input: "StandardInput") -> str: The prompt to send to the LLM. """ if self._tokenizer.chat_template is None: - return input[0]["content"] + return [item["content"] for item in input if item["role"] == "user"][0] prompt: str = ( self._tokenizer.apply_chat_template( @@ -271,7 +271,14 @@ def _prepare_batches( batches = {} for i, (instruction, structured_output) in enumerate(inputs): instruction = self.prepare_input(instruction) - instruction_order[instruction] = i + + # We need to convert the instruction to a string to make it hashable + str_instruction = instruction + if not isinstance(instruction, str): + str_instruction = json.dumps(instruction) + + instruction_order[str_instruction] = i + structured_output = json.dumps(structured_output) if structured_output not in batches: batches[structured_output] = [instruction] @@ -284,7 +291,7 @@ def _prepare_batches( ] # Generate the list of indices based on the original order sorted_indices = [ - instruction_order[instruction] for instruction in flat_instructions + instruction_order[str_instruction] for instruction in flat_instructions ] return [ (batch, json.loads(schema)) for schema, batch in batches.items() @@ -357,7 +364,6 @@ def generate( # type: ignore # Simulate a batch without the structured output content prepared_batches = [([self.prepare_input(input) for input in inputs], None)] sorted_indices = None - # Case in which we have a single structured output for the dataset if self._structured_output_logits_processor: logits_processors.append(self._structured_output_logits_processor) @@ -388,12 +394,15 @@ def generate( # type: ignore **extra_sampling_params, ) - batch_outputs: "RequestOutputs" = self._model.generate( + batch_outputs: List["RequestOutputs"] = self._model.generate( prepared_inputs, sampling_params, use_tqdm=False, # type: ignore ) + # TODO: This is repeated in prepare_output, but for simplicity we extract + # the batched_outputs as we did when there wasn't statistics and we just + # return the str generations batched_outputs += [ [output.text for output in outputs.outputs] for outputs in batch_outputs ] @@ -405,14 +414,16 @@ def generate( # type: ignore ) ) - # TODO: This must be updated with the statistics # If logits_processor is set, we need to sort the outputs back to the original order # (would be needed only if we have multiple structured outputs in the dataset) if sorted_indices is not None: - batched_outputs = _sort_batches( - batched_outputs, sorted_indices, num_generations=num_generations + # Sort the batched outputs together with the statistics + generations = self._prepare_sorted_resuts( + batched_outputs, + sorted_indices, + generations, + num_generations=num_generations, ) - # return batched_outputs return generations def _prepare_structured_output( @@ -445,6 +456,55 @@ def _get_llm_statistics( "output_tokens": output_tokens, } + @staticmethod + def _prepare_sorted_resuts( + batched_outputs: List[List[FormattedInput]], + sorted_indices: List[int], + generations: List[GenerateOutput], + num_generations: int = 1, + ) -> List[GenerateOutput]: + """Helper method to sort the results in case of multiple structured outputs in the dataset. + + Args: + batched_outputs: The mini-batches generated by the model. + sorted_indices: The indices that would sort the mini-batches back to the original order. + generations: The prepared outputs that would be returned in the general case, + from which the statistics will be extracted and sorted. + num_generations: The number of generations requested to vLLM. Defaults to 1. + + Returns: + The list of GenerateOutput sorted back to the original order. + """ + + # This was the only required sort back with only the generations + batched_outputs = _sort_batches( + batched_outputs, sorted_indices, num_generations=num_generations + ) + # Prepare the statistics to be sorted + # Loop over all the variables in the statistics + # Get the keys from the LLMStatistics + statistic_fields = list(generations[0]["statistics"].keys()) + statistics = {} + for field in statistic_fields: + batched_field = _sort_batches( + [g["statistics"][field] for g in generations], + sorted_indices, + num_generations=num_generations, + ) + statistics[field] = batched_field + + # Regenerates the outputs as they are returned buy `preare_output` + sorted_results = [] + for i, batched_output in enumerate(batched_outputs): + generation = {"generations": batched_output} + statistics = { + field: batched_field[i] for field, batched_field in statistics.items() + } + generation.update({"statistics": statistics}) + sorted_results.append(generation) + + return sorted_results + class ClientvLLM(OpenAILLM, MagpieChatTemplateMixin): """A client for the `vLLM` server implementing the OpenAI API specification. diff --git a/tests/unit/llms/test_vllm.py b/tests/unit/llms/test_vllm.py index 7f29790873..96714e1ba3 100644 --- a/tests/unit/llms/test_vllm.py +++ b/tests/unit/llms/test_vllm.py @@ -12,10 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List +from typing import Any, Dict, List from unittest import mock -import numpy as np import pytest from openai.pagination import SyncPage from openai.types import Model @@ -25,7 +24,7 @@ from pydantic import BaseModel from distilabel.llms import vLLM -from distilabel.llms.vllm import ClientvLLM, _sort_batches +from distilabel.llms.vllm import ClientvLLM class Character(BaseModel): @@ -104,7 +103,8 @@ class Animal(BaseModel): # Just a mock to avoid loading the model class DummyTokenizer: - chat_template = None + # chat_template = None + chat_template = "template" def __init__(self) -> None: pass @@ -112,83 +112,93 @@ def __init__(self) -> None: def apply_chat_template(self, input, **kwargs): return input + def encode(self, text: str): + return [1, 2, 3, 4, 5] + class TestvLLM: + @pytest.mark.parametrize("multi_structured_output", (False, True)) @pytest.mark.parametrize( - "num_generations, expected_sorted_batches", + "num_generations, expected_result", [ ( 1, [ - "Generate a character from a RPG game.", - "Generate an animal from a zoo.", - "Repeated character", - "What's the weather like today in Seattle in Celsius degrees?", - "Other character", - "repeated regex", + { + "generations": ["I'm fine thank you"], + "statistics": {"input_tokens": [5], "output_tokens": [6]}, + } ], ), ( - 3, - np.repeat( - [ - "Generate a character from a RPG game.", - "Generate an animal from a zoo.", - "Repeated character", - "What's the weather like today in Seattle in Celsius degrees?", - "Other character", - "repeated regex", - ], - 3, - ).tolist(), + 2, + [ + { + "generations": ["I'm fine thank you"] * 2, + "statistics": {"input_tokens": [5, 5], "output_tokens": [6, 6]}, + } + ], ), ], ) - def test_prepare_batches_and_sort_back( - self, num_generations: int, expected_sorted_batches: List[str] - ): - formatted_inputs = [ - (item["instruction"], item["structured_output"]) - for row in SAMPLE_DATA - for item in row - ] + def test_generate( + self, + multi_structured_output: bool, + num_generations: int, + expected_result: List[Dict[str, Any]], + ) -> None: llm = vLLM(model="dummy") llm._tokenizer = DummyTokenizer() - batches, indices = llm._prepare_batches(formatted_inputs) - # NOTE: We have to simulate calling self._model.generate(n=num_generations) and then sorting the results - num_generations_batches = [] - for batch in batches: - num_generations_batches.append( - (np.repeat(batch[0], num_generations).tolist(), batch[1]) + vllm_mock = mock.MagicMock() + # mock the import by hacking sys.modules + # https://stackoverflow.com/questions/60919705/how-to-mock-in-a-python-unittest-a-library-not-installed-locally + import sys + + if "vllm" not in sys.modules: + sys.modules["vllm"] = vllm_mock + llm._model = vllm_mock + + mocked_requests_output = [ + mock.Mock( # RequestOutput + outputs=[ + mock.Mock( # CompletionOutput + text="I'm fine thank you", + token_ids=[1, 2, 3, 4, 5, 7], + ) + ] + * num_generations, ) - batches = num_generations_batches - # Recreate as the output from batched_outputs += [[output.text for output in outputs.outputs] for outputs in batch_outputs] - batches = [batch for batch, _ in batches] - sorted_batches = _sort_batches( - batches, indices, num_generations=num_generations - ) + ] - assert sorted_batches == [ - np.repeat( - [ - "Generate a character from a RPG game.", - "Generate an animal from a zoo.", - "Repeated character", - ], - num_generations, - ).tolist(), - np.repeat( - ["What's the weather like today in Seattle in Celsius degrees?"], - num_generations, - ).tolist(), - np.repeat( + llm._model.generate = mock.MagicMock(return_value=mocked_requests_output) + if not multi_structured_output: + formatted_inputs = [ [ - "Other character", - "repeated regex", - ], - num_generations, - ).tolist(), - ] + {"role": "system", "content": "sysprompt"}, + { + "role": "user", + "content": "I'm fine thank you", + }, + ] + ] + else: + formatted_inputs = [ + ( + [ + {"role": "system", "content": "sysprompt"}, + { + "role": "user", + "content": "I'm fine thank you", + }, + ], + { + "format": "json", + "schema": Character.model_json_schema(), + }, + ) + ] + result = llm.generate(inputs=formatted_inputs, num_generations=num_generations) + assert result == expected_result @mock.patch("openai.OpenAI") @@ -256,7 +266,7 @@ async def test_agenerate( assert generations == { "generations": ["I'm fine thank you", "I'm fine thank you sir"], "statistics": { - "input_tokens": 10, - "output_tokens": 10, + "input_tokens": [10], + "output_tokens": [10], }, } From f108670a9707a343184340ead79232e263d71e58 Mon Sep 17 00:00:00 2001 From: plaguss Date: Tue, 22 Oct 2024 11:31:06 +0200 Subject: [PATCH 16/35] Checkpoint --- src/distilabel/llms/base.py | 6 + src/distilabel/steps/tasks/base.py | 94 ++++++- tests/unit/conftest.py | 40 ++- tests/unit/steps/tasks/test_base.py | 418 +++++++++++++--------------- 4 files changed, 306 insertions(+), 252 deletions(-) diff --git a/src/distilabel/llms/base.py b/src/distilabel/llms/base.py index cabfd37068..dcdc346a60 100644 --- a/src/distilabel/llms/base.py +++ b/src/distilabel/llms/base.py @@ -459,6 +459,12 @@ async def _agenerate( for input in inputs ] result = await asyncio.gather(*tasks) + print("\n_agenerate\n\n", result) + print("\n_agenerate MERGED\n\n", merge_responses(result)) + print( + "CORRECT merge_response, ITS GROUPING num_generations MIXED WITH THE INPUTS PASSED" + ) + # TODO: Update this, return merge_responses(result) tasks = [ diff --git a/src/distilabel/steps/tasks/base.py b/src/distilabel/steps/tasks/base.py index d27a3b80f5..f8f33df14c 100644 --- a/src/distilabel/steps/tasks/base.py +++ b/src/distilabel/steps/tasks/base.py @@ -34,7 +34,7 @@ from distilabel.utils.dicts import group_dicts if TYPE_CHECKING: - from distilabel.llms.typing import GenerateOutput, LLMOutput, LLMStatistics + from distilabel.llms.typing import GenerateOutput, LLMStatistics from distilabel.steps.tasks.typing import ChatType, FormattedInput from distilabel.steps.typing import StepOutput @@ -170,30 +170,40 @@ def _format_outputs( A list containing a dictionary with the outputs of the task for each input. """ inputs = [None] if input is None else [input] - + print("INPUTS", inputs) formatted_outputs = [] - for output, input in zip(outputs, inputs * len(outputs)): # type: ignore + repeate_inputs = len(outputs.get("generations")) + outputs = normalize_statistics(outputs) + + for (output, stats), input in zip( + iterate_generations_with_stats(outputs), inputs * repeate_inputs + ): # type: ignore + # for output, input in zip(outputs, inputs * len(outputs)): # type: ignore try: # Extract the generations, and move the statistics to the distilabel_metadata, # to keep everything clean - output_generations: "LLMOutput" = output.get("generations", []) - formatted_output = self.format_output(output_generations, input) + # TODO: THIS WOULD FAIL IF THE LLM DOESN'T RETURN generations, + # WE HAVE TO REMOVE THE STATISTICS AND PASS EVERYTHING ELSE + print("OUTPUT", output) + print("STATS", stats) + print("INPUT", input) + # output_generations: "LLMOutput" = output.get("generations", []) + formatted_output = self.format_output(output, input) formatted_output = self._create_metadata( formatted_output, - output_generations, + output, input, add_raw_output=self.add_raw_output, # type: ignore add_raw_input=self.add_raw_input, # type: ignore - statistics=output.get("statistics"), + # statistics=output.get("statistics"), + statistics=stats, ) formatted_outputs.append(formatted_output) except Exception as e: self._logger.warning( # type: ignore f"Task '{self.name}' failed to format output: {e}. Saving raw response." # type: ignore ) - formatted_outputs.append( - self._output_on_failure(output.get("generations", []), input) - ) + formatted_outputs.append(self._output_on_failure(output, input)) return formatted_outputs def _output_on_failure( @@ -437,6 +447,8 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore ) task_outputs = [] + print("INPUTS", inputs) + print("OUTPUTS", outputs) for input, input_outputs in zip(inputs, outputs): formatted_outputs = self._format_outputs(input_outputs, input) @@ -449,6 +461,7 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore # Create a row per generation for formatted_output in formatted_outputs: + print("FORMATED", formatted_output) task_outputs.append( {**input, **formatted_output, "model_name": self.llm.model_name} ) @@ -477,3 +490,64 @@ class GlobalTask(_Task, GlobalStep): """ pass + + +def normalize_statistics(output: "GenerateOutput") -> "GenerateOutput": + """Transforms the GenerateOutput statistics to have the same length as the generations. + + Args: + data: A generate output that possibly has different lengths of statistics + vs generations (due to num_generations=3 returning 3 generations, but + for example the tokens are only counted once). + + Returns: + Normalized statistics according to the generations length. + + Examples: + ```python + data = { + "generations": ["text1", "text2", "text3", "text4"], + "statistics": {"input_tokens": [1], "output_tokens": [1, 2, 3]} + } + normalize_statistics(data) + data = { + "generations": ["text1", "text2", "text3"], + "statistics": {"input_tokens": [1, 1, 1], "output_tokens": [1, 2, 3]} + } + ``` + """ + if not (statistics := output.get("statistics")): + print(statistics) + return output + gen_length = len(output["generations"]) + + for stat_key, stat_values in output["statistics"].items(): + current_length = len(stat_values) + + if current_length < gen_length: + # Calculate how many times to repeat the tokens + repeats = gen_length // current_length + remainder = gen_length % current_length + + # Create new list with repeated values + new_values = stat_values * repeats + stat_values[:remainder] + output["statistics"][stat_key] = new_values + + return output + + +def iterate_generations_with_stats(output: "GenerateOutput") -> "GenerateOutput": + """Helper function to iterate together generations and statistics while + processing them inside _format_outputs. + + Args: + output: Output from the LLM.generate_outputs method. + + Yields: + Iterator of generation and statistics paired. + """ + for i, generation in enumerate(output["generations"]): + # Create a new dictionary with the statistics for this index + stats = {key: values[i] for key, values in output["statistics"].items()} + + yield generation, stats diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 905b0f7231..2127e26a5f 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -39,11 +39,20 @@ def model_name(self) -> str: async def agenerate( # type: ignore self, input: "FormattedInput", num_generations: int = 1 ) -> "GenerateOutput": - # return ["output" for _ in range(num_generations)] - return [ - {"generations": "output", "statistics": {"test": "test"}} - for _ in range(num_generations) - ] + # return { + # "generations": ["output"], + # "statistics": { + # "input_tokens": [12], + # "output_tokens": [12], + # }, + # } + return { + "generations": ["output" for i in range(num_generations)], + "statistics": { + "input_tokens": [12] * num_generations, + "output_tokens": [12] * num_generations, + }, + } class DummyLLM(LLM): @@ -60,11 +69,14 @@ def generate( # type: ignore self, inputs: "FormattedInput", num_generations: int = 1 ) -> List["GenerateOutput"]: return [ - [ - {"generations": "output", "statistics": {"test": "test"}} - for _ in range(num_generations) - ] - ] + { + "generations": [f"output {i}" for i in range(num_generations)], + "statistics": { + "input_tokens": [12] * num_generations, + "output_tokens": [12] * num_generations, + }, + } + ] * len(inputs) class DummyMagpieLLM(LLM, MagpieChatTemplateMixin): @@ -80,7 +92,13 @@ def generate( ) -> List["GenerateOutput"]: return [ [ - {"generations": "output", "statistics": {"test": "test"}} + { + "generations": ["output"] * num_generations, + "statistics": { + "input_tokens": [12] * num_generations, + "output_tokens": [12] * num_generations, + }, + } for _ in range(num_generations) ] for _ in range(len(inputs)) diff --git a/tests/unit/steps/tasks/test_base.py b/tests/unit/steps/tasks/test_base.py index ef400acfc9..87ce7198d3 100644 --- a/tests/unit/steps/tasks/test_base.py +++ b/tests/unit/steps/tasks/test_base.py @@ -109,7 +109,7 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: {"content": "", "role": "system"}, {"content": "test_0", "role": "user"}, ], - "statistics": {"test": "test"}, + "statistics": {"input_tokens": 12, "output_tokens": 12}, }, }, { @@ -124,22 +124,7 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: {"content": "", "role": "system"}, {"content": "test_0", "role": "user"}, ], - "statistics": {"test": "test"}, - }, - }, - { - "instruction": "test_0", - "additional_info": "additional_info_0", - "output": "output", - "info_from_input": "additional_info_0", - "model_name": "test", - "distilabel_metadata": { - "raw_output_task": "output", - "raw_input_task": [ - {"content": "", "role": "system"}, - {"content": "test_0", "role": "user"}, - ], - "statistics": {"test": "test"}, + "statistics": {"input_tokens": 12, "output_tokens": 12}, }, }, { @@ -154,7 +139,7 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: {"content": "", "role": "system"}, {"content": "test_1", "role": "user"}, ], - "statistics": {"test": "test"}, + "statistics": {"input_tokens": 12, "output_tokens": 12}, }, }, { @@ -169,37 +154,7 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: {"content": "", "role": "system"}, {"content": "test_1", "role": "user"}, ], - "statistics": {"test": "test"}, - }, - }, - { - "instruction": "test_1", - "additional_info": "additional_info_1", - "output": "output", - "info_from_input": "additional_info_1", - "model_name": "test", - "distilabel_metadata": { - "raw_output_task": "output", - "raw_input_task": [ - {"content": "", "role": "system"}, - {"content": "test_1", "role": "user"}, - ], - "statistics": {"test": "test"}, - }, - }, - { - "instruction": "test_2", - "additional_info": "additional_info_2", - "output": "output", - "info_from_input": "additional_info_2", - "model_name": "test", - "distilabel_metadata": { - "raw_output_task": "output", - "raw_input_task": [ - {"content": "", "role": "system"}, - {"content": "test_2", "role": "user"}, - ], - "statistics": {"test": "test"}, + "statistics": {"input_tokens": 12, "output_tokens": 12}, }, }, { @@ -214,7 +169,7 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: {"content": "", "role": "system"}, {"content": "test_2", "role": "user"}, ], - "statistics": {"test": "test"}, + "statistics": {"input_tokens": 12, "output_tokens": 12}, }, }, { @@ -229,186 +184,186 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: {"content": "", "role": "system"}, {"content": "test_2", "role": "user"}, ], - "statistics": {"test": "test"}, + "statistics": {"input_tokens": 12, "output_tokens": 12}, }, }, ], ), - ( - [ - {"instruction": "test_0", "additional_info": "additional_info_0"}, - {"instruction": "test_1", "additional_info": "additional_info_1"}, - {"instruction": "test_2", "additional_info": "additional_info_2"}, - ], - True, - [ - { - "instruction": "test_0", - "additional_info": "additional_info_0", - "output": ["output", "output", "output"], - "info_from_input": [ - "additional_info_0", - "additional_info_0", - "additional_info_0", - ], - "model_name": "test", - "distilabel_metadata": [ - { - "raw_output_task": "output", - "raw_input_task": [ - { - "content": "", - "role": "system", - }, - { - "content": "test_0", - "role": "user", - }, - ], - "statistics": {"test": "test"}, - }, - { - "raw_output_task": "output", - "raw_input_task": [ - { - "content": "", - "role": "system", - }, - { - "content": "test_0", - "role": "user", - }, - ], - "statistics": {"test": "test"}, - }, - { - "raw_output_task": "output", - "raw_input_task": [ - { - "content": "", - "role": "system", - }, - { - "content": "test_0", - "role": "user", - }, - ], - "statistics": {"test": "test"}, - }, - ], - }, - { - "instruction": "test_1", - "additional_info": "additional_info_1", - "output": ["output", "output", "output"], - "info_from_input": [ - "additional_info_1", - "additional_info_1", - "additional_info_1", - ], - "model_name": "test", - "distilabel_metadata": [ - { - "raw_output_task": "output", - "raw_input_task": [ - { - "content": "", - "role": "system", - }, - { - "content": "test_1", - "role": "user", - }, - ], - "statistics": {"test": "test"}, - }, - { - "raw_output_task": "output", - "raw_input_task": [ - { - "content": "", - "role": "system", - }, - { - "content": "test_1", - "role": "user", - }, - ], - "statistics": {"test": "test"}, - }, - { - "raw_output_task": "output", - "raw_input_task": [ - { - "content": "", - "role": "system", - }, - { - "content": "test_1", - "role": "user", - }, - ], - "statistics": {"test": "test"}, - }, - ], - }, - { - "instruction": "test_2", - "additional_info": "additional_info_2", - "output": ["output", "output", "output"], - "info_from_input": [ - "additional_info_2", - "additional_info_2", - "additional_info_2", - ], - "model_name": "test", - "distilabel_metadata": [ - { - "raw_output_task": "output", - "raw_input_task": [ - { - "content": "", - "role": "system", - }, - { - "content": "test_2", - "role": "user", - }, - ], - "statistics": {"test": "test"}, - }, - { - "raw_output_task": "output", - "raw_input_task": [ - { - "content": "", - "role": "system", - }, - { - "content": "test_2", - "role": "user", - }, - ], - "statistics": {"test": "test"}, - }, - { - "raw_output_task": "output", - "raw_input_task": [ - { - "content": "", - "role": "system", - }, - { - "content": "test_2", - "role": "user", - }, - ], - "statistics": {"test": "test"}, - }, - ], - }, - ], - ), + # ( + # [ + # {"instruction": "test_0", "additional_info": "additional_info_0"}, + # {"instruction": "test_1", "additional_info": "additional_info_1"}, + # {"instruction": "test_2", "additional_info": "additional_info_2"}, + # ], + # True, + # [ + # { + # "instruction": "test_0", + # "additional_info": "additional_info_0", + # "output": ["output", "output", "output"], + # "info_from_input": [ + # "additional_info_0", + # "additional_info_0", + # "additional_info_0", + # ], + # "model_name": "test", + # "distilabel_metadata": [ + # { + # "raw_output_task": "output", + # "raw_input_task": [ + # { + # "content": "", + # "role": "system", + # }, + # { + # "content": "test_0", + # "role": "user", + # }, + # ], + # "statistics": {"input_tokens": 12, "output_tokens": 12}, + # }, + # { + # "raw_output_task": "output", + # "raw_input_task": [ + # { + # "content": "", + # "role": "system", + # }, + # { + # "content": "test_0", + # "role": "user", + # }, + # ], + # "statistics": {"input_tokens": 12, "output_tokens": 12}, + # }, + # { + # "raw_output_task": "output", + # "raw_input_task": [ + # { + # "content": "", + # "role": "system", + # }, + # { + # "content": "test_0", + # "role": "user", + # }, + # ], + # "statistics": {"input_tokens": 12, "output_tokens": 12}, + # }, + # ], + # }, + # { + # "instruction": "test_1", + # "additional_info": "additional_info_1", + # "output": ["output", "output", "output"], + # "info_from_input": [ + # "additional_info_1", + # "additional_info_1", + # "additional_info_1", + # ], + # "model_name": "test", + # "distilabel_metadata": [ + # { + # "raw_output_task": "output", + # "raw_input_task": [ + # { + # "content": "", + # "role": "system", + # }, + # { + # "content": "test_1", + # "role": "user", + # }, + # ], + # "statistics": {"input_tokens": 12, "output_tokens": 12}, + # }, + # { + # "raw_output_task": "output", + # "raw_input_task": [ + # { + # "content": "", + # "role": "system", + # }, + # { + # "content": "test_1", + # "role": "user", + # }, + # ], + # "statistics": {"input_tokens": 12, "output_tokens": 12}, + # }, + # { + # "raw_output_task": "output", + # "raw_input_task": [ + # { + # "content": "", + # "role": "system", + # }, + # { + # "content": "test_1", + # "role": "user", + # }, + # ], + # "statistics": {"input_tokens": 12, "output_tokens": 12}, + # }, + # ], + # }, + # { + # "instruction": "test_2", + # "additional_info": "additional_info_2", + # "output": ["output", "output", "output"], + # "info_from_input": [ + # "additional_info_2", + # "additional_info_2", + # "additional_info_2", + # ], + # "model_name": "test", + # "distilabel_metadata": [ + # { + # "raw_output_task": "output", + # "raw_input_task": [ + # { + # "content": "", + # "role": "system", + # }, + # { + # "content": "test_2", + # "role": "user", + # }, + # ], + # "statistics": {"input_tokens": 12, "output_tokens": 12}, + # }, + # { + # "raw_output_task": "output", + # "raw_input_task": [ + # { + # "content": "", + # "role": "system", + # }, + # { + # "content": "test_2", + # "role": "user", + # }, + # ], + # "statistics": {"input_tokens": 12, "output_tokens": 12}, + # }, + # { + # "raw_output_task": "output", + # "raw_input_task": [ + # { + # "content": "", + # "role": "system", + # }, + # { + # "content": "test_2", + # "role": "user", + # }, + # ], + # "statistics": {"input_tokens": 12, "output_tokens": 12}, + # }, + # ], + # }, + # ], + # ), ], ) def test_process( @@ -424,7 +379,7 @@ def test_process( llm=llm, pipeline=pipeline, group_generations=group_generations, - num_generations=3, + num_generations=2, ) task.load() result = next(task.process(input)) @@ -436,7 +391,7 @@ def test_process_overriding_inputs(self) -> None: name="task", llm=llm, group_generations=False, - num_generations=3, + num_generations=2, input_mappings={"instruction": "instruction_2"}, ) task.load() @@ -452,6 +407,7 @@ def test_process_overriding_inputs(self) -> None: ] ) ) + print("REUSLT", result) assert result == [ { @@ -468,7 +424,7 @@ def test_process_overriding_inputs(self) -> None: }, ], "raw_output_task": "output", - "statistics": {"test": "test"}, + "statistics": {"input_tokens": 12, "output_tokens": 12}, }, "info_from_input": "info", "instruction": "instruction that won't be used but overriden by input mapping", @@ -490,7 +446,7 @@ def test_process_overriding_inputs(self) -> None: }, ], "raw_output_task": "output", - "statistics": {"test": "test"}, + "statistics": {"input_tokens": 12, "output_tokens": 12}, }, "info_from_input": "info", "instruction": "instruction that won't be used but overriden by input mapping", @@ -512,7 +468,7 @@ def test_process_overriding_inputs(self) -> None: }, ], "raw_output_task": "output", - "statistics": {"test": "test"}, + "statistics": {"input_tokens": 12, "output_tokens": 12}, }, "info_from_input": "info", "instruction": "instruction that won't be used but overriden by input mapping", From c8063a45460fba007f525936d4c1fef69cbe2ff8 Mon Sep 17 00:00:00 2001 From: plaguss Date: Wed, 23 Oct 2024 09:44:23 +0200 Subject: [PATCH 17/35] Fix tests from merge responses and group generations --- src/distilabel/llms/base.py | 42 ++- src/distilabel/steps/tasks/base.py | 17 +- tests/unit/conftest.py | 12 +- tests/unit/steps/tasks/test_base.py | 537 +++++++++++++--------------- 4 files changed, 282 insertions(+), 326 deletions(-) diff --git a/src/distilabel/llms/base.py b/src/distilabel/llms/base.py index dcdc346a60..6df322e7cc 100644 --- a/src/distilabel/llms/base.py +++ b/src/distilabel/llms/base.py @@ -21,6 +21,7 @@ import time from abc import ABC, abstractmethod from functools import cached_property +from itertools import islice from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from pydantic import BaseModel, ConfigDict, Field, PrivateAttr @@ -459,13 +460,7 @@ async def _agenerate( for input in inputs ] result = await asyncio.gather(*tasks) - print("\n_agenerate\n\n", result) - print("\n_agenerate MERGED\n\n", merge_responses(result)) - print( - "CORRECT merge_response, ITS GROUPING num_generations MIXED WITH THE INPUTS PASSED" - ) - # TODO: Update this, - return merge_responses(result) + return result tasks = [ asyncio.create_task(self.agenerate(input=input, **kwargs)) @@ -473,7 +468,7 @@ async def _agenerate( for _ in range(num_generations) ] outputs = await asyncio.gather(*tasks) - return merge_responses(outputs) + return merge_responses(outputs, n=num_generations) def generate( self, @@ -595,26 +590,41 @@ def _prepare_kwargs( return arguments -def merge_responses(responses: List[Dict[str, Any]]) -> List["GenerateOutput"]: +def merge_responses( + responses: List[Dict[str, Any]], n: int = 1 +) -> List[Dict[str, Any]]: """Helper function to group the responses from `LLM.agenerate` method according to the number of generations requested. Args: responses: the responses from the `LLM.agenerate` method. + n: number of responses to group together. Defaults to 1. Returns: - Merges the texts and statistics of the responses into a single response. + List of merged responses, where each merged response contains n generations + and their corresponding statistics. """ if not responses: return [] - first = responses[0] - return [ - { - "generations": sum((r["generations"] for r in responses), []), + def chunks(lst, n): + """Yield successive n-sized chunks from lst.""" + for i in range(0, len(lst), n): + yield list(islice(lst, i, i + n)) + + # Split responses into groups of size n + grouped_responses = list(chunks(responses, n)) + + result = [] + for group in grouped_responses: + first = group[0] + merged = { + "generations": sum((r["generations"] for r in group), []), "statistics": { - key: sum((r["statistics"][key] for r in responses), []) + key: sum((r["statistics"][key] for r in group), []) for key in first["statistics"] }, } - ] + result.append(merged) + + return result diff --git a/src/distilabel/steps/tasks/base.py b/src/distilabel/steps/tasks/base.py index f8f33df14c..d879d13536 100644 --- a/src/distilabel/steps/tasks/base.py +++ b/src/distilabel/steps/tasks/base.py @@ -170,7 +170,6 @@ def _format_outputs( A list containing a dictionary with the outputs of the task for each input. """ inputs = [None] if input is None else [input] - print("INPUTS", inputs) formatted_outputs = [] repeate_inputs = len(outputs.get("generations")) outputs = normalize_statistics(outputs) @@ -178,16 +177,9 @@ def _format_outputs( for (output, stats), input in zip( iterate_generations_with_stats(outputs), inputs * repeate_inputs ): # type: ignore - # for output, input in zip(outputs, inputs * len(outputs)): # type: ignore try: # Extract the generations, and move the statistics to the distilabel_metadata, # to keep everything clean - # TODO: THIS WOULD FAIL IF THE LLM DOESN'T RETURN generations, - # WE HAVE TO REMOVE THE STATISTICS AND PASS EVERYTHING ELSE - print("OUTPUT", output) - print("STATS", stats) - print("INPUT", input) - # output_generations: "LLMOutput" = output.get("generations", []) formatted_output = self.format_output(output, input) formatted_output = self._create_metadata( formatted_output, @@ -195,7 +187,6 @@ def _format_outputs( input, add_raw_output=self.add_raw_output, # type: ignore add_raw_input=self.add_raw_input, # type: ignore - # statistics=output.get("statistics"), statistics=stats, ) formatted_outputs.append(formatted_output) @@ -224,7 +215,6 @@ def _output_on_failure( ) return outputs - # TODO: Rename to _create_metadata def _create_metadata( self, output: Dict[str, Any], @@ -447,8 +437,6 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore ) task_outputs = [] - print("INPUTS", inputs) - print("OUTPUTS", outputs) for input, input_outputs in zip(inputs, outputs): formatted_outputs = self._format_outputs(input_outputs, input) @@ -461,7 +449,6 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore # Create a row per generation for formatted_output in formatted_outputs: - print("FORMATED", formatted_output) task_outputs.append( {**input, **formatted_output, "model_name": self.llm.model_name} ) @@ -516,8 +503,8 @@ def normalize_statistics(output: "GenerateOutput") -> "GenerateOutput": } ``` """ - if not (statistics := output.get("statistics")): - print(statistics) + statistics = output.get("statistics") + if not statistics: return output gen_length = len(output["generations"]) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 2127e26a5f..87343cf7bf 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -15,6 +15,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union import pytest +from pydantic import PrivateAttr from distilabel.llms.base import LLM, AsyncLLM from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin @@ -28,9 +29,11 @@ # Defined here too, so that the serde still works class DummyAsyncLLM(AsyncLLM): structured_output: Any = None + n_generations_supported: bool = True # To work as OpenAI or an LLM that doesn't allow num_generations out of the box + _num_generations_param_supported: bool = PrivateAttr(default=True) def load(self) -> None: - pass + self._num_generations_param_supported = self.n_generations_supported @property def model_name(self) -> str: @@ -39,13 +42,6 @@ def model_name(self) -> str: async def agenerate( # type: ignore self, input: "FormattedInput", num_generations: int = 1 ) -> "GenerateOutput": - # return { - # "generations": ["output"], - # "statistics": { - # "input_tokens": [12], - # "output_tokens": [12], - # }, - # } return { "generations": ["output" for i in range(num_generations)], "statistics": { diff --git a/tests/unit/steps/tasks/test_base.py b/tests/unit/steps/tasks/test_base.py index 87ce7198d3..d0772eb850 100644 --- a/tests/unit/steps/tasks/test_base.py +++ b/tests/unit/steps/tasks/test_base.py @@ -86,6 +86,10 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: ): Task(name="task", llm=DummyAsyncLLM()) # type: ignore + @pytest.mark.parametrize( + "n_generations_supported", + [True, False], + ) @pytest.mark.parametrize( "input, group_generations, expected", [ @@ -189,181 +193,136 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: }, ], ), - # ( - # [ - # {"instruction": "test_0", "additional_info": "additional_info_0"}, - # {"instruction": "test_1", "additional_info": "additional_info_1"}, - # {"instruction": "test_2", "additional_info": "additional_info_2"}, - # ], - # True, - # [ - # { - # "instruction": "test_0", - # "additional_info": "additional_info_0", - # "output": ["output", "output", "output"], - # "info_from_input": [ - # "additional_info_0", - # "additional_info_0", - # "additional_info_0", - # ], - # "model_name": "test", - # "distilabel_metadata": [ - # { - # "raw_output_task": "output", - # "raw_input_task": [ - # { - # "content": "", - # "role": "system", - # }, - # { - # "content": "test_0", - # "role": "user", - # }, - # ], - # "statistics": {"input_tokens": 12, "output_tokens": 12}, - # }, - # { - # "raw_output_task": "output", - # "raw_input_task": [ - # { - # "content": "", - # "role": "system", - # }, - # { - # "content": "test_0", - # "role": "user", - # }, - # ], - # "statistics": {"input_tokens": 12, "output_tokens": 12}, - # }, - # { - # "raw_output_task": "output", - # "raw_input_task": [ - # { - # "content": "", - # "role": "system", - # }, - # { - # "content": "test_0", - # "role": "user", - # }, - # ], - # "statistics": {"input_tokens": 12, "output_tokens": 12}, - # }, - # ], - # }, - # { - # "instruction": "test_1", - # "additional_info": "additional_info_1", - # "output": ["output", "output", "output"], - # "info_from_input": [ - # "additional_info_1", - # "additional_info_1", - # "additional_info_1", - # ], - # "model_name": "test", - # "distilabel_metadata": [ - # { - # "raw_output_task": "output", - # "raw_input_task": [ - # { - # "content": "", - # "role": "system", - # }, - # { - # "content": "test_1", - # "role": "user", - # }, - # ], - # "statistics": {"input_tokens": 12, "output_tokens": 12}, - # }, - # { - # "raw_output_task": "output", - # "raw_input_task": [ - # { - # "content": "", - # "role": "system", - # }, - # { - # "content": "test_1", - # "role": "user", - # }, - # ], - # "statistics": {"input_tokens": 12, "output_tokens": 12}, - # }, - # { - # "raw_output_task": "output", - # "raw_input_task": [ - # { - # "content": "", - # "role": "system", - # }, - # { - # "content": "test_1", - # "role": "user", - # }, - # ], - # "statistics": {"input_tokens": 12, "output_tokens": 12}, - # }, - # ], - # }, - # { - # "instruction": "test_2", - # "additional_info": "additional_info_2", - # "output": ["output", "output", "output"], - # "info_from_input": [ - # "additional_info_2", - # "additional_info_2", - # "additional_info_2", - # ], - # "model_name": "test", - # "distilabel_metadata": [ - # { - # "raw_output_task": "output", - # "raw_input_task": [ - # { - # "content": "", - # "role": "system", - # }, - # { - # "content": "test_2", - # "role": "user", - # }, - # ], - # "statistics": {"input_tokens": 12, "output_tokens": 12}, - # }, - # { - # "raw_output_task": "output", - # "raw_input_task": [ - # { - # "content": "", - # "role": "system", - # }, - # { - # "content": "test_2", - # "role": "user", - # }, - # ], - # "statistics": {"input_tokens": 12, "output_tokens": 12}, - # }, - # { - # "raw_output_task": "output", - # "raw_input_task": [ - # { - # "content": "", - # "role": "system", - # }, - # { - # "content": "test_2", - # "role": "user", - # }, - # ], - # "statistics": {"input_tokens": 12, "output_tokens": 12}, - # }, - # ], - # }, - # ], - # ), + ( + [ + {"instruction": "test_0", "additional_info": "additional_info_0"}, + {"instruction": "test_1", "additional_info": "additional_info_1"}, + {"instruction": "test_2", "additional_info": "additional_info_2"}, + ], + True, + [ + { + "instruction": "test_0", + "additional_info": "additional_info_0", + "output": ["output", "output"], + "info_from_input": [ + "additional_info_0", + "additional_info_0", + ], + "model_name": "test", + "distilabel_metadata": [ + { + "raw_output_task": "output", + "raw_input_task": [ + { + "content": "", + "role": "system", + }, + { + "content": "test_0", + "role": "user", + }, + ], + "statistics": {"input_tokens": 12, "output_tokens": 12}, + }, + { + "raw_output_task": "output", + "raw_input_task": [ + { + "content": "", + "role": "system", + }, + { + "content": "test_0", + "role": "user", + }, + ], + "statistics": {"input_tokens": 12, "output_tokens": 12}, + }, + ], + }, + { + "instruction": "test_1", + "additional_info": "additional_info_1", + "output": ["output", "output"], + "info_from_input": [ + "additional_info_1", + "additional_info_1", + ], + "model_name": "test", + "distilabel_metadata": [ + { + "raw_output_task": "output", + "raw_input_task": [ + { + "content": "", + "role": "system", + }, + { + "content": "test_1", + "role": "user", + }, + ], + "statistics": {"input_tokens": 12, "output_tokens": 12}, + }, + { + "raw_output_task": "output", + "raw_input_task": [ + { + "content": "", + "role": "system", + }, + { + "content": "test_1", + "role": "user", + }, + ], + "statistics": {"input_tokens": 12, "output_tokens": 12}, + }, + ], + }, + { + "instruction": "test_2", + "additional_info": "additional_info_2", + "output": ["output", "output"], + "info_from_input": [ + "additional_info_2", + "additional_info_2", + ], + "model_name": "test", + "distilabel_metadata": [ + { + "raw_output_task": "output", + "raw_input_task": [ + { + "content": "", + "role": "system", + }, + { + "content": "test_2", + "role": "user", + }, + ], + "statistics": {"input_tokens": 12, "output_tokens": 12}, + }, + { + "raw_output_task": "output", + "raw_input_task": [ + { + "content": "", + "role": "system", + }, + { + "content": "test_2", + "role": "user", + }, + ], + "statistics": {"input_tokens": 12, "output_tokens": 12}, + }, + ], + }, + ], + ), ], ) def test_process( @@ -371,9 +330,11 @@ def test_process( input: List[Dict[str, str]], group_generations: bool, expected: List[Dict[str, Any]], + n_generations_supported: bool, ) -> None: pipeline = Pipeline(name="unit-test-pipeline") - llm = DummyAsyncLLM() + llm = DummyAsyncLLM(n_generations_supported=n_generations_supported) + llm.load() task = DummyTask( name="task", llm=llm, @@ -391,7 +352,7 @@ def test_process_overriding_inputs(self) -> None: name="task", llm=llm, group_generations=False, - num_generations=2, + num_generations=3, input_mappings={"instruction": "instruction_2"}, ) task.load() @@ -407,8 +368,6 @@ def test_process_overriding_inputs(self) -> None: ] ) ) - print("REUSLT", result) - assert result == [ { "additional_info": "info", @@ -536,117 +495,121 @@ def test_serialization(self) -> None: pipeline = Pipeline(name="unit-test-pipeline") llm = DummyAsyncLLM() task = DummyTask(name="task", llm=llm, pipeline=pipeline) - assert task.dump() == { - "name": "task", - "add_raw_output": True, - "add_raw_input": True, - "input_mappings": {}, - "output_mappings": {}, - "resources": { - "cpus": None, - "gpus": None, - "memory": None, - "replicas": 1, - "resources": None, - }, - "input_batch_size": 50, - "llm": { - "generation_kwargs": {}, - "structured_output": None, - "jobs_ids": None, - "offline_batch_generation_block_until_done": None, - "use_offline_batch_generation": False, - "type_info": { - "module": "tests.unit.conftest", - "name": "DummyAsyncLLM", - }, - }, - "group_generations": False, - "num_generations": 1, - "runtime_parameters_info": [ - { - "name": "resources", - "runtime_parameters_info": [ - { - "description": "The number of replicas for the step.", - "name": "replicas", - "optional": True, - }, - { - "description": "The number of CPUs assigned to each step replica.", - "name": "cpus", - "optional": True, - }, - { - "description": "The number of GPUs assigned to each step replica.", - "name": "gpus", - "optional": True, - }, - { - "description": "The memory in bytes required for each step replica.", - "name": "memory", - "optional": True, - }, - { - "description": "A dictionary containing names of custom resources and the number of those resources required for each step replica.", - "name": "resources", - "optional": True, - }, - ], + assert ( + task.dump() + == { + "name": "task", + "add_raw_output": True, + "add_raw_input": True, + "input_mappings": {}, + "output_mappings": {}, + "resources": { + "cpus": None, + "gpus": None, + "memory": None, + "replicas": 1, + "resources": None, }, - { - "description": "The number of rows that will contain the batches processed by the step.", - "name": "input_batch_size", - "optional": True, - }, - { - "name": "llm", - "runtime_parameters_info": [ - { - "description": "The kwargs to be propagated to either `generate` or " - "`agenerate` methods within each `LLM`.", - "keys": [], - "name": "generation_kwargs", - }, - { - "description": "Whether to use the `offline_batch_generate` method to " - "generate the responses.", - "name": "use_offline_batch_generation", - "optional": True, - }, - { - "description": "If provided, then polling will be done until the " - "`ofline_batch_generate` method is able to retrieve the " - "results. The value indicate the time to wait between each " - "polling.", - "name": "offline_batch_generation_block_until_done", - "optional": True, - }, - ], - }, - { - "description": "Whether to include the raw output of the LLM in the key `raw_output_` of the `distilabel_metadata` dictionary output column", - "name": "add_raw_output", - "optional": True, - }, - { - "description": "Whether to include the raw input of the LLM in the key `raw_input_` of the `distilabel_metadata` dictionary column", - "name": "add_raw_input", - "optional": True, + "input_batch_size": 50, + "llm": { + "generation_kwargs": {}, + "structured_output": None, + "n_generations_supported": True, # Just a trick during testing, it won't appear otherwise + "jobs_ids": None, + "offline_batch_generation_block_until_done": None, + "use_offline_batch_generation": False, + "type_info": { + "module": "tests.unit.conftest", + "name": "DummyAsyncLLM", + }, }, - { - "name": "num_generations", - "description": "The number of generations to be produced per input.", - "optional": True, + "group_generations": False, + "num_generations": 1, + "runtime_parameters_info": [ + { + "name": "resources", + "runtime_parameters_info": [ + { + "description": "The number of replicas for the step.", + "name": "replicas", + "optional": True, + }, + { + "description": "The number of CPUs assigned to each step replica.", + "name": "cpus", + "optional": True, + }, + { + "description": "The number of GPUs assigned to each step replica.", + "name": "gpus", + "optional": True, + }, + { + "description": "The memory in bytes required for each step replica.", + "name": "memory", + "optional": True, + }, + { + "description": "A dictionary containing names of custom resources and the number of those resources required for each step replica.", + "name": "resources", + "optional": True, + }, + ], + }, + { + "description": "The number of rows that will contain the batches processed by the step.", + "name": "input_batch_size", + "optional": True, + }, + { + "name": "llm", + "runtime_parameters_info": [ + { + "description": "The kwargs to be propagated to either `generate` or " + "`agenerate` methods within each `LLM`.", + "keys": [], + "name": "generation_kwargs", + }, + { + "description": "Whether to use the `offline_batch_generate` method to " + "generate the responses.", + "name": "use_offline_batch_generation", + "optional": True, + }, + { + "description": "If provided, then polling will be done until the " + "`ofline_batch_generate` method is able to retrieve the " + "results. The value indicate the time to wait between each " + "polling.", + "name": "offline_batch_generation_block_until_done", + "optional": True, + }, + ], + }, + { + "description": "Whether to include the raw output of the LLM in the key `raw_output_` of the `distilabel_metadata` dictionary output column", + "name": "add_raw_output", + "optional": True, + }, + { + "description": "Whether to include the raw input of the LLM in the key `raw_input_` of the `distilabel_metadata` dictionary column", + "name": "add_raw_input", + "optional": True, + }, + { + "name": "num_generations", + "description": "The number of generations to be produced per input.", + "optional": True, + }, + ], + "use_cache": True, + "type_info": { + "module": "tests.unit.conftest", + "name": "DummyTask", }, - ], - "use_cache": True, - "type_info": { - "module": "tests.unit.conftest", - "name": "DummyTask", - }, - "use_default_structured_output": False, - } + "use_default_structured_output": False, + } + ) with Pipeline(name="unit-test-pipeline") as pipeline: new_task = DummyTask.from_dict(task.dump()) From 6f6769a14f9acde5433e8c6222d57c6e937f12a8 Mon Sep 17 00:00:00 2001 From: plaguss Date: Wed, 23 Oct 2024 09:52:58 +0200 Subject: [PATCH 18/35] Move import to guarded type hint --- src/distilabel/llms/cohere.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/distilabel/llms/cohere.py b/src/distilabel/llms/cohere.py index 8dba8a8c90..959c7a58b0 100644 --- a/src/distilabel/llms/cohere.py +++ b/src/distilabel/llms/cohere.py @@ -25,7 +25,6 @@ import orjson from pydantic import Field, PrivateAttr, SecretStr, validate_call -from tokenizers import Tokenizer from distilabel.llms.base import AsyncLLM from distilabel.llms.typing import GenerateOutput @@ -39,6 +38,7 @@ if TYPE_CHECKING: from cohere import AsyncClient, ChatMessage, Message from pydantic import BaseModel + from tokenizers import Tokenizer from distilabel.llms.typing import LLMStatistics From 4971f262792224bf250c6e9fe959ae6f9abd8061 Mon Sep 17 00:00:00 2001 From: plaguss Date: Wed, 23 Oct 2024 16:13:31 +0200 Subject: [PATCH 19/35] Fix tests to work with statistics --- src/distilabel/steps/tasks/base.py | 3 +- .../steps/tasks/evol_instruct/base.py | 51 +++++-- .../steps/tasks/evol_instruct/generator.py | 56 ++++++-- .../steps/tasks/evol_quality/base.py | 23 +++- .../unit/steps/tasks/apigen/test_generator.py | 13 +- .../steps/tasks/evol_instruct/test_base.py | 19 +++ .../tasks/evol_instruct/test_generator.py | 28 +++- .../steps/tasks/evol_quality/test_base.py | 13 ++ tests/unit/steps/tasks/test_decorator.py | 8 +- .../tasks/test_improving_text_embeddings.py | 127 ++++++++++++++++-- .../tasks/test_instruction_backtranslation.py | 16 ++- .../steps/tasks/test_sentence_transformers.py | 21 --- .../steps/tasks/test_structured_generation.py | 15 ++- .../steps/tasks/test_text_classification.py | 17 ++- .../unit/steps/tasks/test_text_generation.py | 6 +- tests/unit/steps/tasks/test_ultrafeedback.py | 23 ++-- 16 files changed, 351 insertions(+), 88 deletions(-) diff --git a/src/distilabel/steps/tasks/base.py b/src/distilabel/steps/tasks/base.py index d879d13536..df012ffe2a 100644 --- a/src/distilabel/steps/tasks/base.py +++ b/src/distilabel/steps/tasks/base.py @@ -245,6 +245,8 @@ def _create_metadata( if add_raw_input: meta[f"raw_input_{self.name}"] = self.format_input(input) if input else None if statistics: + # TODO: STATISTICS SHOULD BE GENERATED USING THE STEP NAME TO AVOID OVERWRITING THEM + # meta[f"statistics_{self.name}"] = statistics meta["statistics"] = statistics if meta: output[DISTILABEL_METADATA_KEY] = meta @@ -427,7 +429,6 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore formatted_inputs = self._format_inputs(inputs) - # `outputs` is a list containing a list of generations per input # `outputs` is a dict containing the LLM outputs in the `generations` # key and the statistics in the `statistics` key outputs = self.llm.generate_outputs( diff --git a/src/distilabel/steps/tasks/evol_instruct/base.py b/src/distilabel/steps/tasks/evol_instruct/base.py index 95f271a117..8cae197324 100644 --- a/src/distilabel/steps/tasks/evol_instruct/base.py +++ b/src/distilabel/steps/tasks/evol_instruct/base.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from collections import defaultdict +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import numpy as np from pydantic import Field @@ -26,6 +27,7 @@ from distilabel.utils.lists import flatten_responses if TYPE_CHECKING: + from distilabel.llms.typing import LLMStatistics from distilabel.steps.typing import StepOutput @@ -267,6 +269,7 @@ def _evolve_instructions(self, inputs: "StepInput") -> List[List[str]]: """ instructions: List[List[str]] = [[input["instruction"]] for input in inputs] + statistics: "LLMStatistics" = defaultdict(list) for iter_no in range(self.num_evolutions): formatted_prompts = [] @@ -276,12 +279,16 @@ def _evolve_instructions(self, inputs: "StepInput") -> List[List[str]]: formatted_prompts = [ self.format_input(prompt) for prompt in formatted_prompts ] + responses = self.llm.generate( + formatted_prompts, + **self.llm.generation_kwargs, # type: ignore + ) generated_prompts = flatten_responses( - self.llm.generate( - formatted_prompts, - **self.llm.generation_kwargs, # type: ignore - ) + [response["generations"] for response in responses] ) + for response in responses: + for k, v in response["statistics"].items(): + statistics[k].append(v[0]) evolved_instructions = [] for generated_prompt in generated_prompts: @@ -304,12 +311,11 @@ def _evolve_instructions(self, inputs: "StepInput") -> List[List[str]]: self._logger.info( f"🔄 Ran iteration {iter_no} evolving {len(instructions)} instructions!" ) - - return instructions + return instructions, dict(statistics) def _generate_answers( self, evolved_instructions: List[List[str]] - ) -> List[List[str]]: + ) -> Tuple[List[List[str]], "LLMStatistics"]: """Generates the answer for the instructions in `instructions`. Args: @@ -331,16 +337,23 @@ def _generate_answers( num_generations=1, **self.llm.generation_kwargs, # type: ignore ) + generations = [response["generations"] for response in responses] + + statistics: Dict[str, Any] = defaultdict(list) + for response in responses: + for k, v in response["statistics"].items(): + statistics[k].append(v[0]) step = ( self.num_evolutions if not self.include_original_instruction else self.num_evolutions + 1 ) + return [ - flatten_responses(responses[i : i + step]) + flatten_responses(generations[i : i + step]) for i in range(0, len(responses), step) - ] + ], dict(statistics) @override def process(self, inputs: StepInput) -> "StepOutput": # type: ignore @@ -353,7 +366,7 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore A list of Python dictionaries with the outputs of the task. """ - evolved_instructions = self._evolve_instructions(inputs) + evolved_instructions, statistics = self._evolve_instructions(inputs) if self.store_evolutions: # Remove the input instruction from the `evolved_instructions` list @@ -365,6 +378,13 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore if not self.generate_answers: for input, instruction in zip(inputs, evolved_instructions): input.update(self.format_output(instruction)) + input.update( + { + "distilabel_metadata": { + f"statistics_instruction_{self.name}": statistics + } + } + ) yield inputs self._logger.info( @@ -376,7 +396,7 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore f"🧠 Generating answers for the {len(evolved_instructions)} evolved instructions!" ) - answers = self._generate_answers(evolved_instructions) + answers, statistics = self._generate_answers(evolved_instructions) self._logger.info( f"🎉 Finished generating answers for the {len(evolved_instructions)} evolved" @@ -387,6 +407,13 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore zip(inputs, evolved_instructions) ): input.update(self.format_output(instruction, answers[idx])) + input.update( + { + "distilabel_metadata": { + f"statistics_answer_{self.name}": statistics + } + } + ) yield inputs @override diff --git a/src/distilabel/steps/tasks/evol_instruct/generator.py b/src/distilabel/steps/tasks/evol_instruct/generator.py index 1f56c866a3..ba618330f0 100644 --- a/src/distilabel/steps/tasks/evol_instruct/generator.py +++ b/src/distilabel/steps/tasks/evol_instruct/generator.py @@ -19,8 +19,9 @@ else: import importlib.resources as importlib_resources +from collections import defaultdict from functools import cached_property -from typing import TYPE_CHECKING, Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import numpy as np from pydantic import Field, PrivateAttr @@ -32,6 +33,7 @@ from distilabel.utils.lists import flatten_responses if TYPE_CHECKING: + from distilabel.llms.typing import LLMStatistics from distilabel.steps.tasks.typing import ChatType from distilabel.steps.typing import GeneratorStepOutput @@ -256,7 +258,9 @@ def _apply_random_mutation(self, iter_no: int) -> List["ChatType"]: prompts.append([{"role": "user", "content": prompt_with_template}]) return prompts - def _generate_answers(self, instructions: List[List[str]]) -> List[str]: + def _generate_answers( + self, instructions: List[List[str]] + ) -> Tuple[List[str], "LLMStatistics"]: """Generates the answer for the last instruction in `instructions`. Args: @@ -276,10 +280,17 @@ def _generate_answers(self, instructions: List[List[str]]) -> List[str]: _formatted_instructions, **self.llm.generation_kwargs, # type: ignore ) - return flatten_responses(responses) + statistics: Dict[str, Any] = defaultdict(list) + for response in responses: + for k, v in response["statistics"].items(): + statistics[k].append(v[0]) + + return flatten_responses( + [response["generations"] for response in responses] + ), dict(statistics) @override - def process(self, offset: int = 0) -> "GeneratorStepOutput": # type: ignore + def process(self, offset: int = 0) -> "GeneratorStepOutput": # NOQA: C901, type: ignore """Processes the inputs of the task and generates the outputs using the LLM. Args: @@ -297,9 +308,17 @@ def process(self, offset: int = 0) -> "GeneratorStepOutput": # type: ignore while len(instructions) < self.num_instructions: prompts = self._apply_random_mutation(iter_no=iter_no) + # TODO: Update the function to extract from the dict + responses = self.llm.generate(prompts, **self.llm.generation_kwargs) # type: ignore + generated_prompts = flatten_responses( - self.llm.generate(prompts, **self.llm.generation_kwargs) # type: ignore + [response["generations"] for response in responses] ) + statistics: "LLMStatistics" = defaultdict(list) + for response in responses: + for k, v in response["statistics"].items(): + statistics[k].append(v[0]) + for idx, generated_prompt in enumerate(generated_prompts): generated_prompt = generated_prompt.split("Prompt#:")[-1].strip() if self.max_length >= len(generated_prompt) >= self.min_length: # type: ignore @@ -319,11 +338,15 @@ def process(self, offset: int = 0) -> "GeneratorStepOutput": # type: ignore mutation_no = len(instructions) - mutation_no if not self.generate_answers and len(instructions[-mutation_no:]) > 0: + formatted_generations = [] + for mutated_instruction in instructions[-mutation_no:]: + mutated_instruction = self.format_output(mutated_instruction) + mutated_instruction["distilabel_metadata"] = { + f"statistics_instruction_{self.name}": dict(statistics) + } + formatted_generations.append(mutated_instruction) yield ( - [ - self.format_output(mutated_instruction) - for mutated_instruction in instructions[-mutation_no:] - ], + formatted_generations, len(instructions) >= self.num_instructions, ) @@ -334,17 +357,22 @@ def process(self, offset: int = 0) -> "GeneratorStepOutput": # type: ignore f"🧠 Generating answers for the {len(instructions)} evolved instructions!" ) - answers = self._generate_answers(instructions) + answers, statistics = self._generate_answers(instructions) self._logger.info( f"🎉 Finished generating answers for the {len(instructions)} evolved instructions!" ) + formatted_outputs = [] + for instruction, answer in zip(instructions, answers): + formatted_output = self.format_output(instruction, answer) + formatted_output["distilabel_metadata"] = { + f"statistics_answer_{self.name}": dict(statistics) + } + formatted_outputs.append(formatted_output) + yield ( - [ - self.format_output(instruction, answer) - for instruction, answer in zip(instructions, answers) - ], + formatted_outputs, True, ) diff --git a/src/distilabel/steps/tasks/evol_quality/base.py b/src/distilabel/steps/tasks/evol_quality/base.py index 5c899aa680..d1e8d881eb 100644 --- a/src/distilabel/steps/tasks/evol_quality/base.py +++ b/src/distilabel/steps/tasks/evol_quality/base.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Any, Dict, List, Union +from collections import defaultdict +from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union import numpy as np from pydantic import Field @@ -200,7 +201,9 @@ def _apply_random_mutation(self, instruction: str, response: str) -> str: .replace("", response) ) - def _evolve_reponses(self, inputs: "StepInput") -> List[List[str]]: + def _evolve_reponses( + self, inputs: "StepInput" + ) -> Tuple[List[List[str]], Dict[str, Any]]: """Evolves the instructions provided as part of the inputs of the task. Args: @@ -213,6 +216,7 @@ def _evolve_reponses(self, inputs: "StepInput") -> List[List[str]]: np.random.seed(self.seed) instructions: List[List[str]] = [[input["instruction"]] for input in inputs] responses: List[List[str]] = [[input["response"]] for input in inputs] + statistics: Dict[str, Any] = defaultdict(list) for iter_no in range(self.num_evolutions): formatted_prompts = [] @@ -229,24 +233,28 @@ def _evolve_reponses(self, inputs: "StepInput") -> List[List[str]]: formatted_prompts, **self.llm.generation_kwargs, # type: ignore ) + for response in generated_responses: + for k, v in response["statistics"].items(): + statistics[k].append(v[0]) if self.store_evolutions: responses = [ - response + [evolved_response[0]] + response + [evolved_response["generations"][0]] for response, evolved_response in zip( responses, generated_responses ) ] else: responses = [ - [evolved_response[0]] for evolved_response in generated_responses + [evolved_response["generations"][0]] + for evolved_response in generated_responses ] self._logger.info( f"🔄 Ran iteration {iter_no} evolving {len(responses)} responses!" ) - return responses + return responses, dict(statistics) @override def process(self, inputs: StepInput) -> "StepOutput": # type: ignore @@ -259,7 +267,7 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore A list of Python dictionaries with the outputs of the task. """ - responses = self._evolve_reponses(inputs) + responses, statistics = self._evolve_reponses(inputs) if self.store_evolutions: # Remove the input instruction from the `evolved_responses` list @@ -268,6 +276,9 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore for input, response in zip(inputs, responses): input.update(self.format_output(response)) + input.update( + {"distilabel_metadata": {f"statistics_{self.name}": statistics}} + ) yield inputs self._logger.info(f"🎉 Finished evolving {len(responses)} instructions!") diff --git a/tests/unit/steps/tasks/apigen/test_generator.py b/tests/unit/steps/tasks/apigen/test_generator.py index a290666a60..63e6403915 100644 --- a/tests/unit/steps/tasks/apigen/test_generator.py +++ b/tests/unit/steps/tasks/apigen/test_generator.py @@ -49,9 +49,16 @@ def generate( if self.use_structured_output: query_answers = {"pairs": query_answers} return [ - [json.dumps(query_answers) for _ in range(num_generations)] - for _ in range(len(inputs)) - ] + { + "generations": [ + json.dumps(query_answers) for _ in range(num_generations) + ], + "statistics": { + "input_tokens": [12] * num_generations, + "output_tokens": [12] * num_generations, + }, + } + ] * len(inputs) # Example of 3 rows from Salesforce/xlam-function-calling-60k diff --git a/tests/unit/steps/tasks/evol_instruct/test_base.py b/tests/unit/steps/tasks/evol_instruct/test_base.py index 66f67347b1..478443f4ee 100644 --- a/tests/unit/steps/tasks/evol_instruct/test_base.py +++ b/tests/unit/steps/tasks/evol_instruct/test_base.py @@ -69,6 +69,12 @@ def test_process(self, dummy_llm: LLM) -> None: "instruction": "test", "evolved_instruction": "output", "model_name": "test", + "distilabel_metadata": { + "statistics_instruction_task": { + "input_tokens": [12, 12], + "output_tokens": [12, 12], + } + }, } ] ] @@ -89,6 +95,12 @@ def test_process_store_evolutions(self, dummy_llm: LLM) -> None: "instruction": "test", "evolved_instructions": ["output", "output"], "model_name": "test", + "distilabel_metadata": { + "statistics_instruction_task": { + "input_tokens": [12, 12], + "output_tokens": [12, 12], + } + }, } ] ] @@ -110,6 +122,12 @@ def test_process_generate_answers(self, dummy_llm: LLM) -> None: "evolved_instruction": "output", "answer": "output", "model_name": "test", + "distilabel_metadata": { + "statistics_answer_task": { + "input_tokens": [12], + "output_tokens": [12], + } + }, } ] ] @@ -140,6 +158,7 @@ def test_serialization(self, dummy_llm: LLM) -> None: "jobs_ids": None, "offline_batch_generation_block_until_done": None, "use_offline_batch_generation": False, + "n_generations_supported": True, "type_info": { "module": task.llm.__module__, "name": task.llm.__class__.__name__, diff --git a/tests/unit/steps/tasks/evol_instruct/test_generator.py b/tests/unit/steps/tasks/evol_instruct/test_generator.py index 8f86b94908..4148101fc6 100644 --- a/tests/unit/steps/tasks/evol_instruct/test_generator.py +++ b/tests/unit/steps/tasks/evol_instruct/test_generator.py @@ -64,19 +64,36 @@ def test_process(self, dummy_llm: LLM) -> None: task = EvolInstructGenerator( name="task", llm=dummy_llm, - num_instructions=1, + num_instructions=2, min_length=1, max_length=10, pipeline=pipeline, ) task.load() + assert list(task.process()) == [ ( [ { "instruction": "output", "model_name": "test", - } + "distilabel_metadata": { + "statistics_instruction_task": { + "input_tokens": [12, 12], + "output_tokens": [12, 12], + } + }, + }, + { + "instruction": "output", + "model_name": "test", + "distilabel_metadata": { + "statistics_instruction_task": { + "input_tokens": [12, 12], + "output_tokens": [12, 12], + } + }, + }, ], True, ) @@ -101,6 +118,12 @@ def test_process_generate_answers(self, dummy_llm: LLM) -> None: "instruction": "output", "answer": "output", "model_name": "test", + "distilabel_metadata": { + "statistics_answer_task": { + "input_tokens": [12], + "output_tokens": [12], + } + }, } ], True, @@ -122,6 +145,7 @@ def test_serialization(self, dummy_llm: LLM) -> None: "jobs_ids": None, "offline_batch_generation_block_until_done": None, "use_offline_batch_generation": False, + "n_generations_supported": True, "type_info": { "module": task.llm.__class__.__module__, "name": task.llm.__class__.__name__, diff --git a/tests/unit/steps/tasks/evol_quality/test_base.py b/tests/unit/steps/tasks/evol_quality/test_base.py index 2ac460afc4..586b6b1f99 100644 --- a/tests/unit/steps/tasks/evol_quality/test_base.py +++ b/tests/unit/steps/tasks/evol_quality/test_base.py @@ -60,6 +60,12 @@ def test_process(self, dummy_llm: LLM) -> None: "response": "mock", "evolved_response": "output", "model_name": "test", + "distilabel_metadata": { + "statistics": { + "input_tokens": [12, 12], + "output_tokens": [12, 12], + } + }, } ] ] @@ -81,6 +87,12 @@ def test_process_store_evolutions(self, dummy_llm: LLM) -> None: "response": "mock", "evolved_responses": ["output", "output"], "model_name": "test", + "distilabel_metadata": { + "statistics": { + "input_tokens": [12, 12], + "output_tokens": [12, 12], + } + }, } ] ] @@ -111,6 +123,7 @@ def test_serialization(self, dummy_llm: LLM) -> None: "jobs_ids": None, "offline_batch_generation_block_until_done": None, "use_offline_batch_generation": False, + "n_generations_supported": True, "type_info": { "module": task.llm.__module__, "name": task.llm.__class__.__name__, diff --git a/tests/unit/steps/tasks/test_decorator.py b/tests/unit/steps/tasks/test_decorator.py index 085153c1f8..0f53d6d235 100644 --- a/tests/unit/steps/tasks/test_decorator.py +++ b/tests/unit/steps/tasks/test_decorator.py @@ -181,7 +181,7 @@ def MyTask( { "task": "summarize", "instruction": "The cell...", - "response": "output", + "response": "output 0", "model_name": "test", "distilabel_metadata": { "raw_input_my_task_0": [ @@ -194,7 +194,11 @@ def MyTask( "role": "user", }, ], - "raw_output_my_task_0": "output", + "raw_output_my_task_0": "output 0", + "statistics": { + "input_tokens": 12, + "output_tokens": 12, + }, }, } ] diff --git a/tests/unit/steps/tasks/test_improving_text_embeddings.py b/tests/unit/steps/tasks/test_improving_text_embeddings.py index dfaa247b91..c73ba86107 100644 --- a/tests/unit/steps/tasks/test_improving_text_embeddings.py +++ b/tests/unit/steps/tasks/test_improving_text_embeddings.py @@ -45,7 +45,15 @@ def model_name(self) -> str: def generate( # type: ignore self, inputs: List[ChatType], num_generations: int = 1 ) -> List[GenerateOutput]: - return [[self.output] for _ in range(num_generations)] + return [ + { + "generations": [self.output for _ in range(num_generations)], + "statistics": { + "input_tokens": [12] * num_generations, + "output_tokens": [12] * num_generations, + }, + } + ] * len(inputs) class TestEmbeddingTaskGenerator: @@ -74,13 +82,54 @@ def test_process(self, category: str, flatten_tasks: bool) -> None: assert task.outputs == ["tasks" if not flatten_tasks else "task", "model_name"] result = ( - ([{"tasks": ["A", "B", "C"], "model_name": "test"}], True) + ( + [ + { + "tasks": ["A", "B", "C"], + "model_name": "test", + "distilabel_metadata": { + "statistics": { + "input_tokens": 12, + "output_tokens": 12, + } + }, + } + ], + True, + ) if not flatten_tasks else ( [ - {"task": "A", "model_name": "test"}, - {"task": "B", "model_name": "test"}, - {"task": "C", "model_name": "test"}, + { + "task": "A", + "model_name": "test", + "distilabel_metadata": { + "statistics": { + "input_tokens": 12, + "output_tokens": 12, + } + }, + }, + { + "task": "B", + "model_name": "test", + "distilabel_metadata": { + "statistics": { + "input_tokens": 12, + "output_tokens": 12, + } + }, + }, + { + "task": "C", + "model_name": "test", + "distilabel_metadata": { + "statistics": { + "input_tokens": 12, + "output_tokens": 12, + } + }, + }, ], True, ) @@ -131,7 +180,20 @@ def test_process(self) -> None: assert task.outputs == ["S1", "S2", "S3", "model_name"] assert next(task.process()) == ( - [{"S1": "A", "S2": "B", "S3": "C", "model_name": "test"}], + [ + { + "S1": "A", + "S2": "B", + "S3": "C", + "model_name": "test", + "distilabel_metadata": { + "statistics": { + "input_tokens": 12, + "output_tokens": 12, + } + }, + } + ], True, ) @@ -192,7 +254,20 @@ def test_process(self) -> None: task.load() assert task.outputs == ["S1", "S2", "S3", "model_name"] assert next(task.process()) == ( - [{"S1": "A", "S2": "B", "S3": "C", "model_name": "test"}], + [ + { + "S1": "A", + "S2": "B", + "S3": "C", + "model_name": "test", + "distilabel_metadata": { + "statistics": { + "input_tokens": 12, + "output_tokens": 12, + } + }, + } + ], True, ) @@ -241,7 +316,18 @@ def test_process(self) -> None: assert task.outputs == ["input", "positive_document", "model_name"] assert next(task.process(inputs=[{"task": "A"}])) == [ - {"task": "A", "input": "A", "positive_document": "B", "model_name": "test"} + { + "task": "A", + "input": "A", + "positive_document": "B", + "model_name": "test", + "distilabel_metadata": { + "statistics": { + "input_tokens": 12, + "output_tokens": 12, + } + }, + } ] @@ -271,7 +357,18 @@ def test_process(self) -> None: task.load() assert task.outputs == ["input", "positive_document", "model_name"] assert next(task.process(inputs=[{"task": "A"}])) == [ - {"task": "A", "input": "A", "positive_document": "B", "model_name": "test"} + { + "task": "A", + "input": "A", + "positive_document": "B", + "model_name": "test", + "distilabel_metadata": { + "statistics": { + "input_tokens": 12, + "output_tokens": 12, + } + }, + } ] def test_reproducibility(self) -> None: @@ -333,6 +430,12 @@ def test_process(self) -> None: "label": "B", "misleading_label": "C", "model_name": "test", + "distilabel_metadata": { + "statistics": { + "input_tokens": 12, + "output_tokens": 12, + } + }, } ] @@ -410,5 +513,11 @@ def test_process(self) -> None: "positive_document": "B", "hard_negative_document": "C", "model_name": "test", + "distilabel_metadata": { + "statistics": { + "input_tokens": 12, + "output_tokens": 12, + } + }, } ] diff --git a/tests/unit/steps/tasks/test_instruction_backtranslation.py b/tests/unit/steps/tasks/test_instruction_backtranslation.py index 1b2f9adffa..ef34bca511 100644 --- a/tests/unit/steps/tasks/test_instruction_backtranslation.py +++ b/tests/unit/steps/tasks/test_instruction_backtranslation.py @@ -35,7 +35,15 @@ def generate( self, inputs: List[ChatType], num_generations: int = 1, **kwargs: Any ) -> List[GenerateOutput]: return [ - ["This is the reason. Score: 1" for _ in range(num_generations)] + { + "generations": [ + "This is the reason. Score: 1" for _ in range(num_generations) + ], + "statistics": { + "input_tokens": [12] * num_generations, + "output_tokens": [12] * num_generations, + }, + } for _ in inputs ] @@ -88,7 +96,11 @@ def test_process(self) -> None: "reason": "This is the reason.", "model_name": "instruction-backtranslation-model", "distilabel_metadata": { - "raw_output_instruction-backtranslation": "This is the reason. Score: 1" + "raw_output_instruction-backtranslation": "This is the reason. Score: 1", + "statistics": { + "input_tokens": 12, + "output_tokens": 12, + }, }, } ] diff --git a/tests/unit/steps/tasks/test_sentence_transformers.py b/tests/unit/steps/tasks/test_sentence_transformers.py index 9dc6b38ae1..8df92e903d 100644 --- a/tests/unit/steps/tasks/test_sentence_transformers.py +++ b/tests/unit/steps/tasks/test_sentence_transformers.py @@ -26,27 +26,6 @@ ) from tests.unit.conftest import DummyAsyncLLM -# from distilabel.llms.base import LLM, AsyncLLM - -# if TYPE_CHECKING: -# from distilabel.llms.typing import GenerateOutput -# from distilabel.steps.tasks.typing import FormattedInput - -# # Defined here too, so that the serde still works -# class DummyStructuredLLM(LLM): -# structured_output: Any = None -# def load(self) -> None: -# pass - -# @property -# def model_name(self) -> str: -# return "test" - -# def generate( -# self, input: "FormattedInput", num_generations: int = 1 -# ) -> "GenerateOutput": -# return ['{ \n "negative": "negative",\n "positive": "positive"\n}' for _ in range(num_generations)] - class TestGenerateSentencePair: @pytest.mark.parametrize( diff --git a/tests/unit/steps/tasks/test_structured_generation.py b/tests/unit/steps/tasks/test_structured_generation.py index a57d0da7df..3c85caa47b 100644 --- a/tests/unit/steps/tasks/test_structured_generation.py +++ b/tests/unit/steps/tasks/test_structured_generation.py @@ -37,7 +37,15 @@ def generate( # type: ignore self, inputs: List["StructuredInput"], num_generations: int = 1, **kwargs: Any ) -> List["GenerateOutput"]: return [ - [json.dumps({"test": "output"}) for _ in range(num_generations)] + { + "generations": [ + json.dumps({"test": "output"}) for _ in range(num_generations) + ], + "statistics": { + "input_tokens": [12] * num_generations, + "output_tokens": [12] * num_generations, + }, + } for _ in inputs ] @@ -123,6 +131,9 @@ def test_process(self) -> None: }, "generation": '{"test": "output"}', "model_name": "test", - "distilabel_metadata": {"raw_output_task": '{"test": "output"}'}, + "distilabel_metadata": { + "raw_output_task": '{"test": "output"}', + "statistics": {"input_tokens": 12, "output_tokens": 12}, + }, } ] diff --git a/tests/unit/steps/tasks/test_text_classification.py b/tests/unit/steps/tasks/test_text_classification.py index e5af171b33..739e8a4ffe 100644 --- a/tests/unit/steps/tasks/test_text_classification.py +++ b/tests/unit/steps/tasks/test_text_classification.py @@ -32,11 +32,18 @@ async def agenerate( # type: ignore self, input: "FormattedInput", num_generations: int = 1 ) -> "GenerateOutput": if self.n == 1: - return [json.dumps({"labels": "label"}) for _ in range(num_generations)] - return [ - json.dumps({"labels": [f"label_{i}" for i in range(self.n)]}) - for _ in range(num_generations) - ] + labels = "label" + else: + labels = ["label_0", "label_1", "label_2"] + return { + "generations": [ + json.dumps({"labels": labels}) for _ in range(num_generations) + ], + "statistics": { + "input_tokens": [12] * num_generations, + "output_tokens": [12] * num_generations, + }, + } class TestTextClassification: diff --git a/tests/unit/steps/tasks/test_text_generation.py b/tests/unit/steps/tasks/test_text_generation.py index 2a6abefb22..7262f2c408 100644 --- a/tests/unit/steps/tasks/test_text_generation.py +++ b/tests/unit/steps/tasks/test_text_generation.py @@ -103,6 +103,7 @@ def test_process(self) -> None: "model_name": "test", "distilabel_metadata": { "raw_output_task": "output", + "statistics": {"input_tokens": 12, "output_tokens": 12}, }, } ] @@ -230,6 +231,9 @@ def test_process(self) -> None: "messages": [{"role": "user", "content": "Tell me a joke."}], "generation": "output", "model_name": "test", - "distilabel_metadata": {"raw_output_task": "output"}, + "distilabel_metadata": { + "raw_output_task": "output", + "statistics": {"input_tokens": 12, "output_tokens": 12}, + }, } ] diff --git a/tests/unit/steps/tasks/test_ultrafeedback.py b/tests/unit/steps/tasks/test_ultrafeedback.py index 5565065d61..1adf58fe94 100644 --- a/tests/unit/steps/tasks/test_ultrafeedback.py +++ b/tests/unit/steps/tasks/test_ultrafeedback.py @@ -36,12 +36,17 @@ def generate( self, inputs: List[ChatType], num_generations: int = 1, **kwargs: Any ) -> List[GenerateOutput]: return [ - [ - "Type: 1\nRationale: text\nRating: 1\nRationale: text\n\nType: 2\nRationale: text\nRating: 2\nRationale: text" - for _ in range(num_generations) - ] - for _ in inputs - ] + { + "generations": [ + "Type: 1\nRationale: text\nRating: 1\nRationale: text\n\nType: 2\nRationale: text\nRating: 2\nRationale: text" + for i in range(num_generations) + ], + "statistics": { + "input_tokens": [12] * num_generations, + "output_tokens": [12] * num_generations, + }, + } + ] * len(inputs) class TestUltraFeedback: @@ -65,7 +70,8 @@ def test_process_with_simple_aspect(self) -> None: "rationales": ["text", "text"], "model_name": "ultrafeedback-model", "distilabel_metadata": { - "raw_output_ultrafeedback": "Type: 1\nRationale: text\nRating: 1\nRationale: text\n\nType: 2\nRationale: text\nRating: 2\nRationale: text" + "raw_output_ultrafeedback": "Type: 1\nRationale: text\nRating: 1\nRationale: text\n\nType: 2\nRationale: text\nRating: 2\nRationale: text", + "statistics": {"input_tokens": 12, "output_tokens": 12}, }, } ] @@ -92,7 +98,8 @@ def test_process_with_complex_aspect(self) -> None: "rationales-for-ratings": ["text", "text"], "model_name": "ultrafeedback-model", "distilabel_metadata": { - "raw_output_ultrafeedback": "Type: 1\nRationale: text\nRating: 1\nRationale: text\n\nType: 2\nRationale: text\nRating: 2\nRationale: text" + "raw_output_ultrafeedback": "Type: 1\nRationale: text\nRating: 1\nRationale: text\n\nType: 2\nRationale: text\nRating: 2\nRationale: text", + "statistics": {"input_tokens": 12, "output_tokens": 12}, }, } ] From 74f81adb1c62af23e1373cc6fc85065acf319ac7 Mon Sep 17 00:00:00 2001 From: plaguss Date: Thu, 24 Oct 2024 13:18:38 +0200 Subject: [PATCH 20/35] Return void list in case of no generations --- src/distilabel/llms/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/distilabel/llms/utils.py b/src/distilabel/llms/utils.py index 2661e7e207..d0d411c90f 100644 --- a/src/distilabel/llms/utils.py +++ b/src/distilabel/llms/utils.py @@ -59,7 +59,7 @@ def prepare_output( return { "generations": generations, "statistics": { - "input_tokens": input_tokens or 0, - "output_tokens": output_tokens or 0, + "input_tokens": input_tokens or [], + "output_tokens": output_tokens or [], }, } From 8ff6e138173b85ec440018bd89dcf89aacd76cd8 Mon Sep 17 00:00:00 2001 From: plaguss Date: Thu, 24 Oct 2024 13:19:52 +0200 Subject: [PATCH 21/35] Update function to allow flatten inner list in values of dicts, and add new merge_dicts to help merging user-assistant messages in magpie --- src/distilabel/utils/dicts.py | 64 +++++++++++++++++++++++++++++++++-- 1 file changed, 62 insertions(+), 2 deletions(-) diff --git a/src/distilabel/utils/dicts.py b/src/distilabel/utils/dicts.py index 6c651ae32a..a6eb07831f 100644 --- a/src/distilabel/utils/dicts.py +++ b/src/distilabel/utils/dicts.py @@ -14,17 +14,19 @@ import json from collections import defaultdict +from itertools import chain from typing import Any, Dict, List, TypeVar _K = TypeVar("_K") -def group_dicts(*dicts: Dict[_K, Any]) -> Dict[_K, List[Any]]: +def group_dicts(*dicts: Dict[_K, Any], flatten: bool = False) -> Dict[_K, List[Any]]: """Combines multiple dictionaries into a single dictionary joining the values as a list for each key. Args: *dicts: the dictionaries to be combined. + flatten: whether to flatten the list of values for each key. Returns: The combined dictionary. @@ -33,8 +35,66 @@ def group_dicts(*dicts: Dict[_K, Any]) -> Dict[_K, List[Any]]: for d in dicts: for key, value in d.items(): combined_dict[key].append(value) - return dict(combined_dict) + + combined_dict = dict(combined_dict) + if flatten: + combined_dict = { + k: list(chain.from_iterable(v)) for k, v in combined_dict.items() + } + return combined_dict def flatten_dict(x: Dict[Any, Any]) -> Dict[Any, Any]: return {k: json.dumps(v) if isinstance(v, dict) else v for k, v in x.items()} + + +def merge_dicts(*dict_lists): + """ + Merge N lists of dictionaries with matching keys. + The keys can be any strings, but they must match across all dictionaries within each position. + + Args: + *dict_lists: Variable number of lists of dictionaries + + Returns: + list: Merged list of dictionaries with combined values + + Raises: + ValueError: If lists have different lengths or dictionaries have mismatched keys + """ + if not dict_lists: + return [] + + # Verify all lists have the same length + first_len = len(dict_lists[0]) + if not all(len(d) == first_len for d in dict_lists): + raise ValueError("All input lists must have the same length") + + # For each position, get keys from first list's dictionary + result = [] + for i in range(first_len): + # Get keys from the first dictionary at this position + keys = set(dict_lists[0][i].keys()) + + # Verify all dictionaries at this position have the same keys + for dict_list in dict_lists: + if set(dict_list[i].keys()) != keys: + raise ValueError( + f"All dictionaries at position {i} must have the same keys" + ) + + merged_dict = {key: [] for key in keys} + + # For each dictionary at position i in all lists + for dict_list in dict_lists: + current_dict = dict_list[i] + for key in keys: + # Ensure value is a list + value = current_dict[key] + if not isinstance(value, list): + value = [value] + merged_dict[key].extend(value) + + result.append(merged_dict) + + return result From d8f2a8b1fb68cd1f646c9f3c17d8a0dd1c2286d5 Mon Sep 17 00:00:00 2001 From: plaguss Date: Thu, 24 Oct 2024 13:20:23 +0200 Subject: [PATCH 22/35] Fix dummy magpie llm --- tests/unit/conftest.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 87343cf7bf..7bda619913 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -87,16 +87,13 @@ def generate( self, inputs: List["FormattedInput"], num_generations: int = 1, **kwargs: Any ) -> List["GenerateOutput"]: return [ - [ - { - "generations": ["output"] * num_generations, - "statistics": { - "input_tokens": [12] * num_generations, - "output_tokens": [12] * num_generations, - }, - } - for _ in range(num_generations) - ] + { + "generations": ["Hello Magpie"] * num_generations, + "statistics": { + "input_tokens": [12] * num_generations, + "output_tokens": [12] * num_generations, + }, + } for _ in range(len(inputs)) ] From 70898dab6368efcf63d21add643377c5553f23cb Mon Sep 17 00:00:00 2001 From: plaguss Date: Thu, 24 Oct 2024 13:20:55 +0200 Subject: [PATCH 23/35] Update tests for magpie --- tests/unit/steps/tasks/magpie/test_base.py | 227 +++++++++++++++++++-- 1 file changed, 207 insertions(+), 20 deletions(-) diff --git a/tests/unit/steps/tasks/magpie/test_base.py b/tests/unit/steps/tasks/magpie/test_base.py index cc13681f9f..ce83536427 100644 --- a/tests/unit/steps/tasks/magpie/test_base.py +++ b/tests/unit/steps/tasks/magpie/test_base.py @@ -86,16 +86,34 @@ def test_process(self) -> None: "instruction": "Hello Magpie", "response": "Hello Magpie", "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12], + "output_tokens": [12, 12], + } + }, }, { "instruction": "Hello Magpie", "response": "Hello Magpie", "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12], + "output_tokens": [12, 12], + } + }, }, { "instruction": "Hello Magpie", "response": "Hello Magpie", "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12], + "output_tokens": [12, 12], + } + }, }, ] @@ -119,6 +137,12 @@ def test_process_with_system_prompt(self) -> None: {"role": "assistant", "content": "Hello Magpie"}, ], "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12, 12, 12], + "output_tokens": [12, 12, 12, 12], + } + }, }, { "conversation": [ @@ -129,6 +153,12 @@ def test_process_with_system_prompt(self) -> None: {"role": "assistant", "content": "Hello Magpie"}, ], "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12, 12, 12], + "output_tokens": [12, 12, 12, 12], + } + }, }, { "conversation": [ @@ -139,6 +169,12 @@ def test_process_with_system_prompt(self) -> None: {"role": "assistant", "content": "Hello Magpie"}, ], "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12, 12, 12], + "output_tokens": [12, 12, 12, 12], + } + }, }, ] @@ -167,6 +203,12 @@ def test_process_with_several_system_prompts(self) -> None: {"role": "assistant", "content": "Hello Magpie"}, ], "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12, 12, 12], + "output_tokens": [12, 12, 12, 12], + } + }, }, { "conversation": [ @@ -177,6 +219,12 @@ def test_process_with_several_system_prompts(self) -> None: {"role": "assistant", "content": "Hello Magpie"}, ], "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12, 12, 12], + "output_tokens": [12, 12, 12, 12], + } + }, }, { "conversation": [ @@ -187,6 +235,12 @@ def test_process_with_several_system_prompts(self) -> None: {"role": "assistant", "content": "Hello Magpie"}, ], "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12, 12, 12], + "output_tokens": [12, 12, 12, 12], + } + }, }, ] @@ -194,10 +248,42 @@ def test_process_failing_generation_for_some_rows(self) -> None: with mock.patch( "tests.unit.conftest.DummyMagpieLLM.generate", side_effect=[ - [["Hello Magpie"], [None], ["Hello Magpie"]], - [["Hello Magpie"], ["Hello Magpie"]], - [["Hello Magpie"], [None]], - [["Hello Magpie"]], + [ + { + "generations": ["Hello Magpie user"], + "statistics": { + "input_tokens": [12], + "output_tokens": [12], + }, + } + ], + [ + { + "generations": [None], + "statistics": { + "input_tokens": [], + "output_tokens": [], + }, + } + ], + [ + { + "generations": [None], + "statistics": { + "input_tokens": [], + "output_tokens": [], + }, + } + ], + [ + { + "generations": ["Hello Magpie assistant"], + "statistics": { + "input_tokens": [12], + "output_tokens": [12], + }, + } + ], ], ): task = Magpie( @@ -206,26 +292,19 @@ def test_process_failing_generation_for_some_rows(self) -> None: task.load() - assert next(task.process(inputs=[{}, {}, {}])) == [ - { - "conversation": [ - {"role": "user", "content": "Hello Magpie"}, - {"role": "assistant", "content": "Hello Magpie"}, - {"role": "user", "content": "Hello Magpie"}, - {"role": "assistant", "content": "Hello Magpie"}, - ], - "model_name": "test", - }, - { - "conversation": [], - "model_name": "test", - }, + assert next(task.process(inputs=[{}])) == [ { "conversation": [ - {"role": "user", "content": "Hello Magpie"}, - {"role": "assistant", "content": "Hello Magpie"}, + {"role": "user", "content": "Hello Magpie user"}, + {"role": "assistant", "content": "Hello Magpie assistant"}, ], "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12], + "output_tokens": [12, 12], + } + }, }, ] @@ -243,6 +322,12 @@ def test_process_with_n_turns(self) -> None: {"role": "assistant", "content": "Hello Magpie"}, ], "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12, 12, 12], + "output_tokens": [12, 12, 12, 12], + } + }, }, { "conversation": [ @@ -252,6 +337,12 @@ def test_process_with_n_turns(self) -> None: {"role": "assistant", "content": "Hello Magpie"}, ], "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12, 12, 12], + "output_tokens": [12, 12, 12, 12], + } + }, }, { "conversation": [ @@ -261,6 +352,12 @@ def test_process_with_n_turns(self) -> None: {"role": "assistant", "content": "Hello Magpie"}, ], "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12, 12, 12], + "output_tokens": [12, 12, 12, 12], + } + }, }, ] @@ -281,6 +378,12 @@ def test_process_with_end_with_user(self) -> None: {"role": "user", "content": "Hello Magpie"}, ], "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12, 12], + "output_tokens": [12, 12, 12], + } + }, }, { "conversation": [ @@ -289,6 +392,12 @@ def test_process_with_end_with_user(self) -> None: {"role": "user", "content": "Hello Magpie"}, ], "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12, 12], + "output_tokens": [12, 12, 12], + } + }, }, { "conversation": [ @@ -297,6 +406,12 @@ def test_process_with_end_with_user(self) -> None: {"role": "user", "content": "Hello Magpie"}, ], "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12, 12], + "output_tokens": [12, 12, 12], + } + }, }, ] @@ -319,6 +434,12 @@ def test_process_with_include_system_prompt(self) -> None: {"role": "assistant", "content": "Hello Magpie"}, ], "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12, 12, 12], + "output_tokens": [12, 12, 12, 12], + } + }, }, { "conversation": [ @@ -329,6 +450,12 @@ def test_process_with_include_system_prompt(self) -> None: {"role": "assistant", "content": "Hello Magpie"}, ], "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12, 12, 12], + "output_tokens": [12, 12, 12, 12], + } + }, }, { "conversation": [ @@ -339,6 +466,12 @@ def test_process_with_include_system_prompt(self) -> None: {"role": "assistant", "content": "Hello Magpie"}, ], "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12, 12, 12], + "output_tokens": [12, 12, 12, 12], + } + }, }, ] @@ -370,6 +503,12 @@ def test_process_with_system_prompt_per_row(self) -> None: {"role": "assistant", "content": "Hello Magpie"}, ], "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12, 12, 12], + "output_tokens": [12, 12, 12, 12], + } + }, }, { "system_prompt": "You're a florist expert assistant.", @@ -381,6 +520,12 @@ def test_process_with_system_prompt_per_row(self) -> None: {"role": "assistant", "content": "Hello Magpie"}, ], "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12, 12, 12], + "output_tokens": [12, 12, 12, 12], + } + }, }, { "system_prompt": "You're a plumber expert assistant.", @@ -392,6 +537,12 @@ def test_process_with_system_prompt_per_row(self) -> None: {"role": "assistant", "content": "Hello Magpie"}, ], "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12, 12, 12], + "output_tokens": [12, 12, 12, 12], + } + }, }, ] @@ -420,18 +571,36 @@ def test_process_with_system_prompt_and_probabilities(self) -> None: "response": "Hello Magpie", "system_prompt_key": "system_prompt_1", "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12], + "output_tokens": [12, 12], + } + }, }, { "instruction": "Hello Magpie", "response": "Hello Magpie", "system_prompt_key": "system_prompt_2", "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12], + "output_tokens": [12, 12], + } + }, }, { "instruction": "Hello Magpie", "response": "Hello Magpie", "system_prompt_key": "system_prompt_1", "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12, 12], + "output_tokens": [12, 12], + } + }, }, ] @@ -447,14 +616,32 @@ def test_process_only_instruction(self) -> None: { "instruction": "Hello Magpie", "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12], + "output_tokens": [12], + } + }, }, { "instruction": "Hello Magpie", "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12], + "output_tokens": [12], + } + }, }, { "instruction": "Hello Magpie", "model_name": "test", + "distilabel_metadata": { + "statistics_magpie_0": { + "input_tokens": [12], + "output_tokens": [12], + } + }, }, ] From 314e17136b6a02be096eb2a182500f4490528e7e Mon Sep 17 00:00:00 2001 From: plaguss Date: Thu, 24 Oct 2024 13:21:41 +0200 Subject: [PATCH 24/35] Create statistics entry in distilabel_metadata with the name of the step --- src/distilabel/steps/tasks/base.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/distilabel/steps/tasks/base.py b/src/distilabel/steps/tasks/base.py index df012ffe2a..5fa9865ecf 100644 --- a/src/distilabel/steps/tasks/base.py +++ b/src/distilabel/steps/tasks/base.py @@ -245,9 +245,7 @@ def _create_metadata( if add_raw_input: meta[f"raw_input_{self.name}"] = self.format_input(input) if input else None if statistics: - # TODO: STATISTICS SHOULD BE GENERATED USING THE STEP NAME TO AVOID OVERWRITING THEM - # meta[f"statistics_{self.name}"] = statistics - meta["statistics"] = statistics + meta[f"statistics_{self.name}"] = statistics if meta: output[DISTILABEL_METADATA_KEY] = meta From e3e81d9ada56b943b73126589d99b7dfb3567bce Mon Sep 17 00:00:00 2001 From: plaguss Date: Thu, 24 Oct 2024 13:22:16 +0200 Subject: [PATCH 25/35] Update magpie code to work with the new llm.generate behaviour --- src/distilabel/steps/tasks/magpie/base.py | 77 +++++++++++++++++------ 1 file changed, 57 insertions(+), 20 deletions(-) diff --git a/src/distilabel/steps/tasks/magpie/base.py b/src/distilabel/steps/tasks/magpie/base.py index a137d931dd..46348ed10d 100644 --- a/src/distilabel/steps/tasks/magpie/base.py +++ b/src/distilabel/steps/tasks/magpie/base.py @@ -27,11 +27,14 @@ ) from distilabel.steps.base import StepInput from distilabel.steps.tasks.base import Task +from distilabel.utils.dicts import merge_dicts if TYPE_CHECKING: + from distilabel.llms.typing import LLMStatistics from distilabel.steps.tasks.typing import ChatType from distilabel.steps.typing import StepColumns, StepOutput + MAGPIE_MULTI_TURN_SYSTEM_PROMPT = ( "You are a helpful Al assistant. The user will engage in a multi−round conversation" " with you, asking initial questions and following up with additional related questions." @@ -192,15 +195,25 @@ def _generate_instruction( num_generations=1, **self.llm.generation_kwargs, # type: ignore ) + stats = [] rows = [] for output, system_prompt_key in zip_longest( outputs, system_prompt_keys, fillvalue=None ): - row = {"instruction": output[0]} # type: ignore + row = { + "instruction": output["generations"][0], + "distilabel_metadata": { + f"statistics_{self.name}": output["statistics"] + }, + } # type: ignore if system_prompt_key is not None: row["system_prompt_key"] = system_prompt_key rows.append(row) - return rows + stats.append( + {} + ) # Mimics the stats to keep _generate_with_pre_query_template + + return rows, stats def _prepare_conversation_outputs( self, conversations: List["ChatType"], system_prompt_keys: List[str] @@ -245,18 +258,21 @@ def _prepare_conversation_outputs( def _generate_conversation_turn( self, role: str, conversations: List["ChatType"], active_indices: List[int] - ) -> Tuple[List["ChatType"], List[int]]: + ) -> Tuple[List["ChatType"], List[int], "LLMStatistics"]: # Generate an output for the conversations that are still active (no previous `None`s) outputs = self.llm.generate( inputs=[conversations[idx] for idx in active_indices], num_generations=1, **self.llm.generation_kwargs, # type: ignore ) + # Extract the single message from the conversation + messages = [output["generations"][0] for output in outputs] + statistics = [output["statistics"] for output in outputs] active_conversations = [conversations[idx] for idx in active_indices] updated_conversations = self._append_messages_to_conversations( role=role, - messages=[output[0] for output in outputs], + messages=messages, conversations=active_conversations, ) @@ -264,10 +280,10 @@ def _generate_conversation_turn( conversations[idx] = conv new_active_indices = [ - idx for idx, output in zip(active_indices, outputs) if output[0] is not None + idx for idx, output in zip(active_indices, outputs) if output is not None ] - return conversations, new_active_indices + return conversations, new_active_indices, statistics def _generate_multi_turn_conversation( self, inputs: List[Dict[str, Any]] @@ -278,30 +294,45 @@ def _generate_multi_turn_conversation( # Keep track of the active conversations, as it could happen that for some conversation # we can't generate the next turn because the `LLM` returned `None`. active_indices = list(range(len(conversations))) - + stats = [] for i in range(self.n_turns): # type: ignore if not active_indices: break # Generate user message - conversations, active_indices = self._generate_conversation_turn( - role="user", conversations=conversations, active_indices=active_indices + conversations, active_indices, statistics_user = ( + self._generate_conversation_turn( + role="user", + conversations=conversations, + active_indices=active_indices, + ) ) if i == self.n_turns - 1 and self.end_with_user: # type: ignore + statistics = merge_dicts(*[statistics_user]) + stats.append(statistics) break if not active_indices: break # Generate assistant message - conversations, active_indices = self._generate_conversation_turn( - role="assistant", - conversations=conversations, - active_indices=active_indices, + conversations, active_indices, statistics_assistant = ( + self._generate_conversation_turn( + role="assistant", + conversations=conversations, + active_indices=active_indices, + ) ) + # Merge the statistics of the user and assistant messages to have the same shape as the conversations + statistics = merge_dicts(*[statistics_user, statistics_assistant]) + stats.append(statistics) - return self._prepare_conversation_outputs(conversations, system_prompt_keys) + # Merge the dicts again at the conversation level + stats = merge_dicts(*stats) + return self._prepare_conversation_outputs( + conversations, system_prompt_keys + ), stats def _generate_with_pre_query_template( self, inputs: List[Dict[str, Any]] @@ -314,16 +345,22 @@ def _generate_with_pre_query_template( Returns: The list of generated conversations. """ - outputs = ( + outputs, statistics = ( self._generate_instruction(inputs) if self.only_instruction else self._generate_multi_turn_conversation(inputs) ) - - return [ - {**input, **output, "model_name": self.llm.model_name} - for input, output in zip(inputs, outputs) - ] + generations = [] + for input, output, stats in zip(inputs, outputs, statistics): + generation = { + **input, + **output, + "model_name": self.llm.model_name, + } + if not self.only_instruction: + generation["distilabel_metadata"] = {f"statistics_{self.name}": stats} + generations.append(generation) + return generations class Magpie(Task, MagpieBase): From 241d899ed55c3ffe24992c388cdf3c17866ba7cb Mon Sep 17 00:00:00 2001 From: plaguss Date: Thu, 24 Oct 2024 13:22:46 +0200 Subject: [PATCH 26/35] Update tests with the llm generate output format --- .../steps/tasks/evol_quality/test_base.py | 4 +- tests/unit/steps/tasks/test_base.py | 68 ++++++++++++++----- tests/unit/steps/tasks/test_decorator.py | 2 +- .../tasks/test_improving_text_embeddings.py | 20 +++--- .../tasks/test_instruction_backtranslation.py | 2 +- .../steps/tasks/test_structured_generation.py | 2 +- .../unit/steps/tasks/test_text_generation.py | 4 +- tests/unit/steps/tasks/test_ultrafeedback.py | 2 +- 8 files changed, 70 insertions(+), 34 deletions(-) diff --git a/tests/unit/steps/tasks/evol_quality/test_base.py b/tests/unit/steps/tasks/evol_quality/test_base.py index 586b6b1f99..10d145c357 100644 --- a/tests/unit/steps/tasks/evol_quality/test_base.py +++ b/tests/unit/steps/tasks/evol_quality/test_base.py @@ -61,7 +61,7 @@ def test_process(self, dummy_llm: LLM) -> None: "evolved_response": "output", "model_name": "test", "distilabel_metadata": { - "statistics": { + "statistics_task": { "input_tokens": [12, 12], "output_tokens": [12, 12], } @@ -88,7 +88,7 @@ def test_process_store_evolutions(self, dummy_llm: LLM) -> None: "evolved_responses": ["output", "output"], "model_name": "test", "distilabel_metadata": { - "statistics": { + "statistics_task": { "input_tokens": [12, 12], "output_tokens": [12, 12], } diff --git a/tests/unit/steps/tasks/test_base.py b/tests/unit/steps/tasks/test_base.py index d0772eb850..ab48a79b09 100644 --- a/tests/unit/steps/tasks/test_base.py +++ b/tests/unit/steps/tasks/test_base.py @@ -113,7 +113,10 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: {"content": "", "role": "system"}, {"content": "test_0", "role": "user"}, ], - "statistics": {"input_tokens": 12, "output_tokens": 12}, + "statistics_task": { + "input_tokens": 12, + "output_tokens": 12, + }, }, }, { @@ -128,7 +131,10 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: {"content": "", "role": "system"}, {"content": "test_0", "role": "user"}, ], - "statistics": {"input_tokens": 12, "output_tokens": 12}, + "statistics_task": { + "input_tokens": 12, + "output_tokens": 12, + }, }, }, { @@ -143,7 +149,10 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: {"content": "", "role": "system"}, {"content": "test_1", "role": "user"}, ], - "statistics": {"input_tokens": 12, "output_tokens": 12}, + "statistics_task": { + "input_tokens": 12, + "output_tokens": 12, + }, }, }, { @@ -158,7 +167,10 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: {"content": "", "role": "system"}, {"content": "test_1", "role": "user"}, ], - "statistics": {"input_tokens": 12, "output_tokens": 12}, + "statistics_task": { + "input_tokens": 12, + "output_tokens": 12, + }, }, }, { @@ -173,7 +185,10 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: {"content": "", "role": "system"}, {"content": "test_2", "role": "user"}, ], - "statistics": {"input_tokens": 12, "output_tokens": 12}, + "statistics_task": { + "input_tokens": 12, + "output_tokens": 12, + }, }, }, { @@ -188,7 +203,10 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: {"content": "", "role": "system"}, {"content": "test_2", "role": "user"}, ], - "statistics": {"input_tokens": 12, "output_tokens": 12}, + "statistics_task": { + "input_tokens": 12, + "output_tokens": 12, + }, }, }, ], @@ -223,7 +241,10 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: "role": "user", }, ], - "statistics": {"input_tokens": 12, "output_tokens": 12}, + "statistics_task": { + "input_tokens": 12, + "output_tokens": 12, + }, }, { "raw_output_task": "output", @@ -237,7 +258,10 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: "role": "user", }, ], - "statistics": {"input_tokens": 12, "output_tokens": 12}, + "statistics_task": { + "input_tokens": 12, + "output_tokens": 12, + }, }, ], }, @@ -263,7 +287,10 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: "role": "user", }, ], - "statistics": {"input_tokens": 12, "output_tokens": 12}, + "statistics_task": { + "input_tokens": 12, + "output_tokens": 12, + }, }, { "raw_output_task": "output", @@ -277,7 +304,10 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: "role": "user", }, ], - "statistics": {"input_tokens": 12, "output_tokens": 12}, + "statistics_task": { + "input_tokens": 12, + "output_tokens": 12, + }, }, ], }, @@ -303,7 +333,10 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: "role": "user", }, ], - "statistics": {"input_tokens": 12, "output_tokens": 12}, + "statistics_task": { + "input_tokens": 12, + "output_tokens": 12, + }, }, { "raw_output_task": "output", @@ -317,7 +350,10 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None: "role": "user", }, ], - "statistics": {"input_tokens": 12, "output_tokens": 12}, + "statistics_task": { + "input_tokens": 12, + "output_tokens": 12, + }, }, ], }, @@ -383,7 +419,7 @@ def test_process_overriding_inputs(self) -> None: }, ], "raw_output_task": "output", - "statistics": {"input_tokens": 12, "output_tokens": 12}, + "statistics_task": {"input_tokens": 12, "output_tokens": 12}, }, "info_from_input": "info", "instruction": "instruction that won't be used but overriden by input mapping", @@ -405,7 +441,7 @@ def test_process_overriding_inputs(self) -> None: }, ], "raw_output_task": "output", - "statistics": {"input_tokens": 12, "output_tokens": 12}, + "statistics_task": {"input_tokens": 12, "output_tokens": 12}, }, "info_from_input": "info", "instruction": "instruction that won't be used but overriden by input mapping", @@ -427,7 +463,7 @@ def test_process_overriding_inputs(self) -> None: }, ], "raw_output_task": "output", - "statistics": {"input_tokens": 12, "output_tokens": 12}, + "statistics_task": {"input_tokens": 12, "output_tokens": 12}, }, "info_from_input": "info", "instruction": "instruction that won't be used but overriden by input mapping", @@ -647,4 +683,4 @@ def test_add_raw_input_and_or_output( assert ( "raw_input_dummy_task_0" in result[0]["distilabel_metadata"].keys() ) - assert "statistics" in result[0]["distilabel_metadata"].keys() + assert "statistics_dummy_task_0" in result[0]["distilabel_metadata"].keys() diff --git a/tests/unit/steps/tasks/test_decorator.py b/tests/unit/steps/tasks/test_decorator.py index 0f53d6d235..2280779799 100644 --- a/tests/unit/steps/tasks/test_decorator.py +++ b/tests/unit/steps/tasks/test_decorator.py @@ -195,7 +195,7 @@ def MyTask( }, ], "raw_output_my_task_0": "output 0", - "statistics": { + "statistics_my_task_0": { "input_tokens": 12, "output_tokens": 12, }, diff --git a/tests/unit/steps/tasks/test_improving_text_embeddings.py b/tests/unit/steps/tasks/test_improving_text_embeddings.py index c73ba86107..59df057c65 100644 --- a/tests/unit/steps/tasks/test_improving_text_embeddings.py +++ b/tests/unit/steps/tasks/test_improving_text_embeddings.py @@ -88,7 +88,7 @@ def test_process(self, category: str, flatten_tasks: bool) -> None: "tasks": ["A", "B", "C"], "model_name": "test", "distilabel_metadata": { - "statistics": { + "statistics_embedding_task_generator": { "input_tokens": 12, "output_tokens": 12, } @@ -104,7 +104,7 @@ def test_process(self, category: str, flatten_tasks: bool) -> None: "task": "A", "model_name": "test", "distilabel_metadata": { - "statistics": { + "statistics_embedding_task_generator": { "input_tokens": 12, "output_tokens": 12, } @@ -114,7 +114,7 @@ def test_process(self, category: str, flatten_tasks: bool) -> None: "task": "B", "model_name": "test", "distilabel_metadata": { - "statistics": { + "statistics_embedding_task_generator": { "input_tokens": 12, "output_tokens": 12, } @@ -124,7 +124,7 @@ def test_process(self, category: str, flatten_tasks: bool) -> None: "task": "C", "model_name": "test", "distilabel_metadata": { - "statistics": { + "statistics_embedding_task_generator": { "input_tokens": 12, "output_tokens": 12, } @@ -187,7 +187,7 @@ def test_process(self) -> None: "S3": "C", "model_name": "test", "distilabel_metadata": { - "statistics": { + "statistics_bitext_retrieval_generator": { "input_tokens": 12, "output_tokens": 12, } @@ -261,7 +261,7 @@ def test_process(self) -> None: "S3": "C", "model_name": "test", "distilabel_metadata": { - "statistics": { + "statistics_monolingual_triplet_generator": { "input_tokens": 12, "output_tokens": 12, } @@ -322,7 +322,7 @@ def test_process(self) -> None: "positive_document": "B", "model_name": "test", "distilabel_metadata": { - "statistics": { + "statistics_generate_long_text_matching_data": { "input_tokens": 12, "output_tokens": 12, } @@ -363,7 +363,7 @@ def test_process(self) -> None: "positive_document": "B", "model_name": "test", "distilabel_metadata": { - "statistics": { + "statistics_generate_short_text_matching_data": { "input_tokens": 12, "output_tokens": 12, } @@ -431,7 +431,7 @@ def test_process(self) -> None: "misleading_label": "C", "model_name": "test", "distilabel_metadata": { - "statistics": { + "statistics_generate_text_classification_data": { "input_tokens": 12, "output_tokens": 12, } @@ -514,7 +514,7 @@ def test_process(self) -> None: "hard_negative_document": "C", "model_name": "test", "distilabel_metadata": { - "statistics": { + "statistics_generate_text_retrieval_data": { "input_tokens": 12, "output_tokens": 12, } diff --git a/tests/unit/steps/tasks/test_instruction_backtranslation.py b/tests/unit/steps/tasks/test_instruction_backtranslation.py index ef34bca511..b108f0beb9 100644 --- a/tests/unit/steps/tasks/test_instruction_backtranslation.py +++ b/tests/unit/steps/tasks/test_instruction_backtranslation.py @@ -97,7 +97,7 @@ def test_process(self) -> None: "model_name": "instruction-backtranslation-model", "distilabel_metadata": { "raw_output_instruction-backtranslation": "This is the reason. Score: 1", - "statistics": { + "statistics_instruction-backtranslation": { "input_tokens": 12, "output_tokens": 12, }, diff --git a/tests/unit/steps/tasks/test_structured_generation.py b/tests/unit/steps/tasks/test_structured_generation.py index 3c85caa47b..4d1288d03d 100644 --- a/tests/unit/steps/tasks/test_structured_generation.py +++ b/tests/unit/steps/tasks/test_structured_generation.py @@ -133,7 +133,7 @@ def test_process(self) -> None: "model_name": "test", "distilabel_metadata": { "raw_output_task": '{"test": "output"}', - "statistics": {"input_tokens": 12, "output_tokens": 12}, + "statistics_task": {"input_tokens": 12, "output_tokens": 12}, }, } ] diff --git a/tests/unit/steps/tasks/test_text_generation.py b/tests/unit/steps/tasks/test_text_generation.py index 7262f2c408..ad9b690430 100644 --- a/tests/unit/steps/tasks/test_text_generation.py +++ b/tests/unit/steps/tasks/test_text_generation.py @@ -103,7 +103,7 @@ def test_process(self) -> None: "model_name": "test", "distilabel_metadata": { "raw_output_task": "output", - "statistics": {"input_tokens": 12, "output_tokens": 12}, + "statistics_task": {"input_tokens": 12, "output_tokens": 12}, }, } ] @@ -233,7 +233,7 @@ def test_process(self) -> None: "model_name": "test", "distilabel_metadata": { "raw_output_task": "output", - "statistics": {"input_tokens": 12, "output_tokens": 12}, + "statistics_task": {"input_tokens": 12, "output_tokens": 12}, }, } ] diff --git a/tests/unit/steps/tasks/test_ultrafeedback.py b/tests/unit/steps/tasks/test_ultrafeedback.py index 1adf58fe94..3875e3d4cf 100644 --- a/tests/unit/steps/tasks/test_ultrafeedback.py +++ b/tests/unit/steps/tasks/test_ultrafeedback.py @@ -99,7 +99,7 @@ def test_process_with_complex_aspect(self) -> None: "model_name": "ultrafeedback-model", "distilabel_metadata": { "raw_output_ultrafeedback": "Type: 1\nRationale: text\nRating: 1\nRationale: text\n\nType: 2\nRationale: text\nRating: 2\nRationale: text", - "statistics": {"input_tokens": 12, "output_tokens": 12}, + "statistics_task": {"input_tokens": 12, "output_tokens": 12}, }, } ] From 40e408dcd94fc06be7fbd6ff8856939fb784d23e Mon Sep 17 00:00:00 2001 From: plaguss Date: Thu, 24 Oct 2024 15:54:42 +0200 Subject: [PATCH 27/35] Fix pending tests --- .../steps/clustering/text_clustering.py | 1 + src/distilabel/steps/tasks/base.py | 1 - .../unit/steps/clustering/test_text_clustering.py | 15 ++++++++++----- .../tasks/structured_outputs/test_outlines.py | 5 ++--- tests/unit/steps/tasks/test_ultrafeedback.py | 10 ++++++++-- 5 files changed, 21 insertions(+), 11 deletions(-) diff --git a/src/distilabel/steps/clustering/text_clustering.py b/src/distilabel/steps/clustering/text_clustering.py index 7e640bf5c1..aa2e267db5 100644 --- a/src/distilabel/steps/clustering/text_clustering.py +++ b/src/distilabel/steps/clustering/text_clustering.py @@ -312,6 +312,7 @@ def process(self, inputs: StepInput) -> "StepOutput": self._logger.info(f"📦 Processing internal batch of inputs {i}...") results = super().process(batched_inputs) for result in next(results): # Extract the elements from the generator + print("INTERMEDIATE RESULTS", result) cluster_summaries[result["__LABEL"]] = result["labels"] # Assign the labels to each text diff --git a/src/distilabel/steps/tasks/base.py b/src/distilabel/steps/tasks/base.py index 5fa9865ecf..a8034c9ec0 100644 --- a/src/distilabel/steps/tasks/base.py +++ b/src/distilabel/steps/tasks/base.py @@ -434,7 +434,6 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore num_generations=self.num_generations, # type: ignore **self.llm.get_generation_kwargs(), # type: ignore ) - task_outputs = [] for input, input_outputs in zip(inputs, outputs): formatted_outputs = self._format_outputs(input_outputs, input) diff --git a/tests/unit/steps/clustering/test_text_clustering.py b/tests/unit/steps/clustering/test_text_clustering.py index 4b2da96d40..e2f555eeb2 100644 --- a/tests/unit/steps/clustering/test_text_clustering.py +++ b/tests/unit/steps/clustering/test_text_clustering.py @@ -32,11 +32,16 @@ async def agenerate( # type: ignore self, input: "FormattedInput", num_generations: int = 1 ) -> "GenerateOutput": if self.n == 1: - return [json.dumps({"labels": "label"}) for _ in range(num_generations)] - return [ - json.dumps({"labels": ["label" for _ in range(self.n)]}) - for _ in range(self.n) - ] + text = json.dumps({"labels": "label"}) + else: + text = json.dumps({"labels": ["label" for _ in range(self.n)]}) + return { + "generations": [text] * num_generations, + "statistics": { + "input_tokens": [12] * num_generations, + "output_tokens": [12] * num_generations, + }, + } class TestTextClustering: diff --git a/tests/unit/steps/tasks/structured_outputs/test_outlines.py b/tests/unit/steps/tasks/structured_outputs/test_outlines.py index d2be053aa5..1b00a6d527 100644 --- a/tests/unit/steps/tasks/structured_outputs/test_outlines.py +++ b/tests/unit/steps/tasks/structured_outputs/test_outlines.py @@ -19,7 +19,6 @@ from distilabel.llms.huggingface.transformers import TransformersLLM from distilabel.steps.tasks.structured_outputs.outlines import ( - # StructuredOutputType, model_to_schema, ) from distilabel.steps.tasks.typing import OutlinesStructuredOutputType @@ -138,8 +137,8 @@ def test_generation( ] result = llm.generate(prompt, max_new_tokens=30) assert isinstance(result, list) - assert isinstance(result[0], list) - assert isinstance(result[0][0], str) + assert isinstance(result[0], dict) + assert "generations" in result[0] and "statistics" in result[0] @pytest.mark.parametrize( "format, schema, dump", diff --git a/tests/unit/steps/tasks/test_ultrafeedback.py b/tests/unit/steps/tasks/test_ultrafeedback.py index 3875e3d4cf..65271e75ec 100644 --- a/tests/unit/steps/tasks/test_ultrafeedback.py +++ b/tests/unit/steps/tasks/test_ultrafeedback.py @@ -71,7 +71,10 @@ def test_process_with_simple_aspect(self) -> None: "model_name": "ultrafeedback-model", "distilabel_metadata": { "raw_output_ultrafeedback": "Type: 1\nRationale: text\nRating: 1\nRationale: text\n\nType: 2\nRationale: text\nRating: 2\nRationale: text", - "statistics": {"input_tokens": 12, "output_tokens": 12}, + "statistics_ultrafeedback": { + "input_tokens": 12, + "output_tokens": 12, + }, }, } ] @@ -99,7 +102,10 @@ def test_process_with_complex_aspect(self) -> None: "model_name": "ultrafeedback-model", "distilabel_metadata": { "raw_output_ultrafeedback": "Type: 1\nRationale: text\nRating: 1\nRationale: text\n\nType: 2\nRationale: text\nRating: 2\nRationale: text", - "statistics_task": {"input_tokens": 12, "output_tokens": 12}, + "statistics_ultrafeedback": { + "input_tokens": 12, + "output_tokens": 12, + }, }, } ] From 6c2e1fd488123931b41a90cc6fe41bf4630d519c Mon Sep 17 00:00:00 2001 From: plaguss Date: Thu, 24 Oct 2024 16:50:46 +0200 Subject: [PATCH 28/35] Fix test failing with vllm version upgrade --- tests/unit/llms/test_vllm.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/tests/unit/llms/test_vllm.py b/tests/unit/llms/test_vllm.py index 96714e1ba3..e0616837f1 100644 --- a/tests/unit/llms/test_vllm.py +++ b/tests/unit/llms/test_vllm.py @@ -22,6 +22,7 @@ from openai.types.completion_choice import CompletionChoice from openai.types.completion_usage import CompletionUsage from pydantic import BaseModel +from transformers import AutoTokenizer from distilabel.llms import vLLM from distilabel.llms.vllm import ClientvLLM @@ -101,10 +102,10 @@ class Animal(BaseModel): ] -# Just a mock to avoid loading the model class DummyTokenizer: # chat_template = None chat_template = "template" + vocabulary = {"I'm": 1, "fine": 2, "thank": 3, "you": 4, "sir": 5} def __init__(self) -> None: pass @@ -115,6 +116,12 @@ def apply_chat_template(self, input, **kwargs): def encode(self, text: str): return [1, 2, 3, 4, 5] + def convert_token_to_string(self, token: str) -> str: + return "token" + + def get_vocab(self): + return self.vocabulary + class TestvLLM: @pytest.mark.parametrize("multi_structured_output", (False, True)) @@ -148,8 +155,11 @@ def test_generate( expected_result: List[Dict[str, Any]], ) -> None: llm = vLLM(model="dummy") - llm._tokenizer = DummyTokenizer() + tokenizer = AutoTokenizer.from_pretrained( + "distilabel-internal-testing/tiny-random-mistral" + ) vllm_mock = mock.MagicMock() + vllm_mock.get_tokenizer = mock.MagicMock(return_value=tokenizer) # mock the import by hacking sys.modules # https://stackoverflow.com/questions/60919705/how-to-mock-in-a-python-unittest-a-library-not-installed-locally import sys @@ -192,8 +202,10 @@ def test_generate( }, ], { - "format": "json", - "schema": Character.model_json_schema(), + # "format": "json", + "format": "regex", + "schema": r".*", + # "schema": Character.model_json_schema(), }, ) ] From d28a7983312427b1293b06ffae3da85b23ac295a Mon Sep 17 00:00:00 2001 From: plaguss Date: Thu, 24 Oct 2024 17:00:14 +0200 Subject: [PATCH 29/35] Another fix including tokenizer for our llm to work and to avoid outlines complaining --- tests/unit/llms/test_vllm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/llms/test_vllm.py b/tests/unit/llms/test_vllm.py index e0616837f1..9c76b36449 100644 --- a/tests/unit/llms/test_vllm.py +++ b/tests/unit/llms/test_vllm.py @@ -158,6 +158,7 @@ def test_generate( tokenizer = AutoTokenizer.from_pretrained( "distilabel-internal-testing/tiny-random-mistral" ) + llm._tokenizer = DummyTokenizer() vllm_mock = mock.MagicMock() vllm_mock.get_tokenizer = mock.MagicMock(return_value=tokenizer) # mock the import by hacking sys.modules From 1bc28ba86221eb04097c07078944ec23f8f4821d Mon Sep 17 00:00:00 2001 From: plaguss Date: Fri, 25 Oct 2024 09:41:09 +0200 Subject: [PATCH 30/35] Fix dummy offline batch generation --- tests/integration/test_offline_batch_generation.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/integration/test_offline_batch_generation.py b/tests/integration/test_offline_batch_generation.py index a9fe880ff7..309e2c62e4 100644 --- a/tests/integration/test_offline_batch_generation.py +++ b/tests/integration/test_offline_batch_generation.py @@ -51,8 +51,15 @@ def offline_batch_generate( raise DistilabelOfflineBatchGenerationNotFinishedException( jobs_ids=self.jobs_ids # type: ignore ) - - return [["output" for _ in range(num_generations)]] + return [ + { + "generations": [f"output {i}" for i in range(num_generations)], + "statistics": { + "input_tokens": [12] * num_generations, + "output_tokens": [12] * num_generations, + }, + } + ] * len(inputs) def test_offline_batch_generation() -> None: From edbea28419143801d8a090c60856864c000e0370 Mon Sep 17 00:00:00 2001 From: plaguss Date: Mon, 28 Oct 2024 09:46:53 +0100 Subject: [PATCH 31/35] Compute tokens using the tokenizer if available --- .../models/llms/huggingface/inference_endpoints.py | 11 +++++++---- .../llms/huggingface/test_inference_endpoints.py | 4 ++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/src/distilabel/models/llms/huggingface/inference_endpoints.py b/src/distilabel/models/llms/huggingface/inference_endpoints.py index b2767c4714..c60199452b 100644 --- a/src/distilabel/models/llms/huggingface/inference_endpoints.py +++ b/src/distilabel/models/llms/huggingface/inference_endpoints.py @@ -32,7 +32,7 @@ from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.models.llms.base import AsyncLLM from distilabel.models.llms.typing import GenerateOutput -from distilabel.models.llms.utils import prepare_output +from distilabel.models.llms.utils import compute_tokens, prepare_output from distilabel.models.mixins.magpie import MagpieChatTemplateMixin from distilabel.steps.tasks.typing import ( FormattedInput, @@ -423,11 +423,14 @@ async def _generate_with_text_generation( f"⚠️ Received no response using Inference Client (model: '{self.model_name}')." f" Finish reason was: {e}" ) - # NOTE: I cannot see the input tokens returned, and given that the model can be private, I cannot - # count them... + return prepare_output( [completion.generated_text], - input_tokens=[0], + input_tokens=[ + compute_tokens(self.prepare_input(input), self._tokenizer.encode) + if self._tokenizer + else 0 + ], output_tokens=[ completion.details.generated_tokens if completion.details else 0 ], diff --git a/tests/unit/models/llms/huggingface/test_inference_endpoints.py b/tests/unit/models/llms/huggingface/test_inference_endpoints.py index 1bac6dfcf6..874cd9a595 100644 --- a/tests/unit/models/llms/huggingface/test_inference_endpoints.py +++ b/tests/unit/models/llms/huggingface/test_inference_endpoints.py @@ -149,7 +149,7 @@ async def test_agenerate_with_text_generation( assert result == { "generations": ["Aenean hendrerit aliquam velit..."], "statistics": { - "input_tokens": [0], + "input_tokens": [31], "output_tokens": [66], }, } @@ -358,7 +358,7 @@ async def test_agenerate_with_structured_output( assert result == { "generations": ["Aenean hendrerit aliquam velit..."], "statistics": { - "input_tokens": [0], + "input_tokens": [31], "output_tokens": [66], }, } From e97f901d8f448f8614eabfa3470d8efffa3c2e4f Mon Sep 17 00:00:00 2001 From: plaguss Date: Mon, 28 Oct 2024 11:17:20 +0100 Subject: [PATCH 32/35] Update docs to include references to the new outputs of the LLMs including statistics --- .../sections/how_to_guides/basic/llm/index.md | 60 +++++++++++++++---- .../how_to_guides/basic/task/index.md | 43 ++++++++----- 2 files changed, 79 insertions(+), 24 deletions(-) diff --git a/docs/sections/how_to_guides/basic/llm/index.md b/docs/sections/how_to_guides/basic/llm/index.md index d5d5a37368..d715994cb4 100644 --- a/docs/sections/how_to_guides/basic/llm/index.md +++ b/docs/sections/how_to_guides/basic/llm/index.md @@ -7,7 +7,10 @@ LLM subclasses are designed to be used within a [Task][distilabel.steps.tasks.Ta ```python from distilabel.models import InferenceEndpointsLLM -llm = InferenceEndpointsLLM(model="meta-llama/Meta-Llama-3.1-70B-Instruct") +llm = InferenceEndpointsLLM( + model_id="meta-llama/Meta-Llama-3.1-70B-Instruct", + tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct" +) llm.load() llm.generate_outputs( @@ -15,12 +18,34 @@ llm.generate_outputs( [{"role": "user", "content": "What's the capital of Spain?"}], ], ) -# "The capital of Spain is Madrid." +# [ +# { +# "generations": [ +# "The capital of Spain is Madrid." +# ], +# "statistics": { +# "input_tokens": [ +# 43 +# ], +# "output_tokens": [ +# 8 +# ] +# } +# } +# ] ``` -!!! NOTE +!!! Note Always call the `LLM.load` or `Task.load` method when using LLMs standalone or as part of a `Task`. If using a `Pipeline`, this is done automatically in `Pipeline.run()`. +!!! Tip "New in version 1.5.0" + Since version `1.5.0` the LLM output is a list of dictionaries (one per item in the `inputs`), + each containing `generations`, that reports the text returned by the `LLM`, and a `statistics` field that will store statistics related to the `LLM` generation. Initially, this will include + `input_tokens` and `output_tokens` when available, which will be obtained via the API when available, or if a tokenizer is available for the model used, using the tokenizer for the model. + This data will be moved by the corresponding `Task` during the pipeline processing and moved to `distilabel_metadata` so we can operate on this data if we want, like for example computing the number of tokens per dataset. + + To access to the previous result one just has to access to the generations in the resulting dictionary: `result[0]["generations"]`. + ### Offline Batch Generation By default, all `LLM`s will generate text in a synchronous manner i.e. send inputs using `generate_outputs` method that will get blocked until outputs are generated. There are some `LLM`s (such as [OpenAILLM][distilabel.models.llms.openai.OpenAILLM]) that implements what we denote as _offline batch generation_, which allows to send the inputs to the LLM-as-a-service which will generate the outputs asynchronously and give us a job id that we can use later to check the status and retrieve the generated outputs when they are ready. LLM-as-a-service platforms offers this feature as a way to save costs in exchange of waiting for the outputs to be generated. @@ -56,7 +81,8 @@ llm.generate_outputs( # (4) [{"role": "user", "content": "What's the capital of Spain?"}], ], ) -# "The capital of Spain is Madrid." +# [{'generations': ['The capital of Spain is Madrid.'], +# 'statistics': {'input_tokens': [13], 'output_tokens': [7]}}] ``` 1. At first the `jobs_ids` attribute is `None`. @@ -81,7 +107,8 @@ llm.generate_outputs( [{"role": "user", "content": "What's the capital of Spain?"}], ], ) -# "The capital of Spain is Madrid." +# [{'generations': ['The capital of Spain is Madrid.'], +# 'statistics': {'input_tokens': [13], 'output_tokens': [7]}}] ``` ### Within a Task @@ -92,20 +119,30 @@ Pass the LLM as an argument to the [`Task`][distilabel.steps.tasks.Task], and th from distilabel.models import OpenAILLM from distilabel.steps.tasks import TextGeneration -llm = OpenAILLM(model="gpt-4") +llm = OpenAILLM(model="gpt-4o-mini") task = TextGeneration(name="text_generation", llm=llm) task.load() next(task.process(inputs=[{"instruction": "What's the capital of Spain?"}])) -# [{'instruction': "What's the capital of Spain?", "generation": "The capital of Spain is Madrid."}] +# [{'instruction': "What's the capital of Spain?", +# 'generation': 'The capital of Spain is Madrid.', +# 'distilabel_metadata': {'raw_output_text_generation': 'The capital of Spain is Madrid.', +# 'raw_input_text_generation': [{'role': 'user', +# 'content': "What's the capital of Spain?"}], +# 'statistics_text_generation': {'input_tokens': 13, 'output_tokens': 7}}, +# 'model_name': 'gpt-4o-mini'}] ``` +!!! Note + As mentioned in *Working with LLMs* section, the generation of an LLM is automatically moved to `distilabel_metadata` to avoid interference with the common workflow, so the addition of the `statistics` it's an extra component available for the user, but nothing has to be changed in the + defined pipelines. + ### Runtime Parameters LLMs can have runtime parameters, such as `generation_kwargs`, provided via the `Pipeline.run()` method using the `params` argument. -!!! NOTE +!!! Note Runtime parameters can differ between LLM subclasses, caused by the different functionalities offered by the LLM providers. ```python @@ -122,7 +159,7 @@ with Pipeline(name="text-generation-pipeline") as pipeline: text_generation = TextGeneration( name="text_generation", - llm=OpenAILLM(model="gpt-4"), + llm=OpenAILLM(model="gpt-4o-mini"), ) load_dataset >> text_generation @@ -200,9 +237,12 @@ To create custom LLMs, subclass either [`LLM`][distilabel.models.llms.LLM] for s `generate` and `agenerate` keyword arguments (but `input` and `num_generations`) are considered as `RuntimeParameter`s, so a value can be passed to them via the `parameters` argument of the `Pipeline.run` method. -!!! NOTE +!!! Note To have the arguments of the `generate` and `agenerate` coerced to the expected types, the `validate_call` decorator is used, which will automatically coerce the arguments to the expected types, and raise an error if the types are not correct. This is specially useful when providing a value for an argument of `generate` or `agenerate` from the CLI, since the CLI will always provide the arguments as strings. +!!! Warning + Additional LLMs created in `distilabel` will have to take into account how the `statistics` are generated to properly include them in the LLM output. + ## Available LLMs [Our LLM gallery](../../../../components-gallery/llms/index.md) shows a list of the available LLMs that can be used within the `distilabel` library. diff --git a/docs/sections/how_to_guides/basic/task/index.md b/docs/sections/how_to_guides/basic/task/index.md index 7f1d8260e0..dd5de6f837 100644 --- a/docs/sections/how_to_guides/basic/task/index.md +++ b/docs/sections/how_to_guides/basic/task/index.md @@ -21,26 +21,35 @@ task.load() next(task.process([{"instruction": "What's the capital of Spain?"}])) # [ -# { -# 'instruction': "What's the capital of Spain?", -# 'generation': 'The capital of Spain is Madrid.', -# 'distilabel_metadata': { -# 'raw_output_text-generation': 'The capital of Spain is Madrid.', -# 'raw_input_text-generation': [ -# {'role': 'user', 'content': "What's the capital of Spain?"} -# ] -# }, -# 'model_name': 'meta-llama/Meta-Llama-3-70B-Instruct' -# } +# { +# "instruction": "What's the capital of Spain?", +# "generation": "The capital of Spain is Madrid.", +# "distilabel_metadata": { +# "raw_output_text-generation": "The capital of Spain is Madrid.", +# "raw_input_text-generation": [ +# { +# "role": "user", +# "content": "What's the capital of Spain?" +# } +# ], +# "statistics_text-generation": { # (1) +# "input_tokens": 18, +# "output_tokens": 8 +# } +# }, +# "model_name": "meta-llama/Meta-Llama-3.1-8B-Instruct" +# } # ] ``` -!!! NOTE +1. The `LLMs` will not only return the text but also a `statistics_{STEP_NAME}` field that will contain statistics related to the generation. If available, at least the input and output tokens will be returned. + +!!! Note The `Step.load()` always needs to be executed when being used as a standalone. Within a pipeline, this will be done automatically during pipeline execution. As shown above, the [`TextGeneration`][distilabel.steps.tasks.TextGeneration] task adds a `generation` based on the `instruction`. -!!! Tip +!!! Tip "New in version 1.2.0" Since version `1.2.0`, we provide some metadata about the LLM call through `distilabel_metadata`. This can be disabled by setting the `add_raw_output` attribute to `False` when creating the task. Additionally, since version `1.4.0`, the formatted input can also be included, which can be helpful when testing @@ -57,9 +66,12 @@ As shown above, the [`TextGeneration`][distilabel.steps.tasks.TextGeneration] ta ) ``` +!!! Tip "New in version 1.5.0" + Since version `1.5.0` `distilabel_metadata` includes a new `statistics` field out of the box. The generation from the LLM will not only contain the text, but also statistics associated with the text if available, like the input and output tokens. This field will be generated with `statistic_{STEP_NAME}` to avoid collisions between different steps in the pipeline, similar to how `raw_output_{STEP_NAME}` works. + ### Task.print -!!! Info +!!! Info "New in version 1.4.0" New since version `1.4.0`, [`Task.print`][distilabel.steps.tasks.base._Task.print] `Task.print` method. The `Tasks` include a handy method to show what the prompt formatted for an `LLM` would look like, let's see an example with [`UltraFeedback`][distilabel.steps.tasks.ultrafeedback.UltraFeedback], but it applies to any other `Task`. @@ -271,3 +283,6 @@ We can define a custom step by creating a new subclass of the [`Task`][distilabe # Format the `LLM` output here return {"output_field": output} ``` + +!!! Warning + Most `Tasks` reuse the `Task.process` method to process the generations, but if a new `Task` defines a custom `process` method, like happens for example with [`Magpie`][distilabel.steps.tasks.magpie.base.Magpie], one hast to deal with the `statistics` returned by the `LLM`. From 84cfe8f60ee8337e81416e25f8482cd5611315b3 Mon Sep 17 00:00:00 2001 From: plaguss Date: Fri, 8 Nov 2024 16:28:40 +0100 Subject: [PATCH 33/35] Update template card to include the statistics table --- src/distilabel/utils/card/distilabel_template.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/distilabel/utils/card/distilabel_template.md b/src/distilabel/utils/card/distilabel_template.md index 38daa7f857..50b1f0234f 100644 --- a/src/distilabel/utils/card/distilabel_template.md +++ b/src/distilabel/utils/card/distilabel_template.md @@ -85,6 +85,15 @@ ds = load_dataset("{{ repo_id }}") {% endif %} +{% if statistics %} +## Dataset Statistics +{% for leaf_name, table in statistics.items() %} +* **Summary statistics**: `{{ leaf_name }}` +{{ table }} +{% endfor %} + +{% endif %} + {% if references %} ## References From 6644d1059f79c18941faba4560f9dde035f1682f Mon Sep 17 00:00:00 2001 From: plaguss Date: Fri, 8 Nov 2024 16:40:57 +0100 Subject: [PATCH 34/35] Add draft with statistics table with token summary --- src/distilabel/distiset.py | 43 +++++++++++++++++++++++++++++++++++++ tests/unit/test_distiset.py | 42 ++++++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+) diff --git a/src/distilabel/distiset.py b/src/distilabel/distiset.py index 8e52c667d3..3a7ce3adc6 100644 --- a/src/distilabel/distiset.py +++ b/src/distilabel/distiset.py @@ -207,6 +207,7 @@ def _get_card( include_script=include_script, filename_py=filename_py, artifacts=self._get_artifacts_metadata(), + statistics=self._get_dataset_statistics(), references=self.citations, ) @@ -238,6 +239,48 @@ def iterdir_ignore_hidden(path: Path) -> Generator[Path, None, None]: return dict(artifacts_metadata) + def _get_dataset_statistics(self) -> Dict[str, str]: + """Builds a dictionary with the statistics of the dataset. + + Returns: + Each key in the dict corresponds to a leaf step in the pipeline, and the value is a markdown + table with the statistics. + """ + + def token_count(row): + metadata = row["distilabel_metadata"] + for col, data in metadata.items(): + if col.startswith("statistics"): + row[f"input_tokens_{col}"] = data["input_tokens"] + row[f"output_tokens_{col}"] = data["output_tokens"] + return row + + def get_token_count(dataset): + # Here we are accessing to a DatasetDict, which can result in error if + # there's no "train" split. + filtered = ( + dataset["train"] + .select_columns("distilabel_metadata") + .map(token_count, num_proc=4) + ) + token_count_columns = [ + col + for col in filtered.column_names + if (col.startswith("input_tokens") or col.startswith("output_tokens")) + ] + select_stat_columns = ["mean", "std", "min", "max"] + df = filtered.select_columns(token_count_columns).to_pandas() + df_stats = df.describe().T[select_stat_columns] + df_stats["sum"] = df.sum(axis=0) + return df_stats.to_markdown() + + stats = {} + + for name, dataset in self.items(): + stats[name] = get_token_count(dataset) + + return stats + def _extract_readme_metadata( self, repo_id: str, token: Optional[str] ) -> Dict[str, Any]: diff --git a/tests/unit/test_distiset.py b/tests/unit/test_distiset.py index 1649a2ff18..f3190c5993 100644 --- a/tests/unit/test_distiset.py +++ b/tests/unit/test_distiset.py @@ -236,3 +236,45 @@ def test_dataset_card(self, distiset: Distiset) -> None: "size_categories": "n<1K", "tags": ["synthetic", "distilabel", "rlaif"], } + + def test_token_counter(self) -> None: + disti = Distiset( + { + "default": DatasetDict( + { + "train": Dataset.from_dict( + { + "instruction": [ + "Generate a list of answers and questions about the document.", + "other", + ], + "generation": ["some generated text", "other"], + "distilabel_metadata": [ + { + "raw_input_generation": [], # Irrelevant in this case + "raw_output_generation": "", # Irrelevant in this case + "statistics_generation": { + "input_tokens": 1882, + "output_tokens": 583, + }, + }, + { + "raw_input_generation": [], # Irrelevant in this case + "raw_output_generation": "", # Irrelevant in this case + "statistics_generation": { + "input_tokens": 1880, + "output_tokens": 581, + }, + }, + ], + } + ) + } + ), + } + ) + statistics_table = disti._get_dataset_statistics() + assert len(statistics_table) == 1 + assert statistics_table.keys() == {"default"} + stats_table = "| | mean | std | min | max | sum |\n|:------------------------------------|-------:|--------:|------:|------:|------:|\n| input_tokens_statistics_generation | 1881 | 1.41421 | 1880 | 1882 | 3762 |\n| output_tokens_statistics_generation | 582 | 1.41421 | 581 | 583 | 1164 |" + assert statistics_table["default"] == stats_table From 5741cd1dde9f34f6f3aa6c2f0498d75eb7d33042 Mon Sep 17 00:00:00 2001 From: plaguss Date: Fri, 8 Nov 2024 16:45:54 +0100 Subject: [PATCH 35/35] Fix jinja2 template --- src/distilabel/utils/card/distilabel_template.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/distilabel/utils/card/distilabel_template.md b/src/distilabel/utils/card/distilabel_template.md index 50b1f0234f..c45dad2762 100644 --- a/src/distilabel/utils/card/distilabel_template.md +++ b/src/distilabel/utils/card/distilabel_template.md @@ -87,8 +87,10 @@ ds = load_dataset("{{ repo_id }}") {% if statistics %} ## Dataset Statistics + {% for leaf_name, table in statistics.items() %} * **Summary statistics**: `{{ leaf_name }}` + {{ table }} {% endfor %}