[BUG] Pipeline serialization/caching issue when including RoutingBatchFunction #1070

Open
liamcripwell opened this issue Nov 25, 2024 · 1 comment

@liamcripwell

Describe the bug
I am experiencing an issue when trying to use a RoutingBatchFunction inside a pipeline. Specifically, I am using sample_n_steps() as shown in the example here: https://distilabel.argilla.io/latest/api/pipeline/routing_batch_function/?h=routing#distilabel.pipeline.routing_batch_function.routing_batch_function

The pipeline initially runs without issue, but if I try to run it again it raises AttributeError: 'NoneType' object has no attribute 'name' from distilabel/src/distilabel/pipeline/routing_batch_function.py:170 in dump (on develop). This looks like a failure when serializing this step.

We can work around this temporarily by manually deleting the cache directory on disk, but the error still occurs even when passing use_cache=False to pipeline.run(). Is caching still required to some degree even when this flag is set?
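
For reference, this is roughly how we clear the cache between runs. The path below is only an assumption about the default distilabel cache location; adjust it to wherever the pipeline cache actually lives on your machine.

import shutil
from pathlib import Path

# Assumed default location of the distilabel pipeline cache; this path is a
# guess and may differ depending on your installation and OS.
cache_dir = Path.home() / ".cache" / "distilabel" / "pipelines"

if cache_dir.exists():
    # Removing the cached pipeline state lets the next run start cleanly.
    shutil.rmtree(cache_dir)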

To Reproduce
Code to reproduce

from distilabel.llms import OpenAILLM
from distilabel.pipeline import Pipeline, sample_n_steps
from distilabel.steps import LoadDataFromHub, GroupColumns
from distilabel.steps.tasks import TextGeneration

random_routing_batch = sample_n_steps(2)


with Pipeline(name="routing-batch-function") as pipeline:
    load_dataset = LoadDataFromHub(
        name="load_dataset",
    )

    generations = []
    for llm in (
        OpenAILLM(model="gpt-4o"),
        OpenAILLM(model="gpt-4o-mini"),
    ):
        task = TextGeneration(
            name=f"text_generation_with_{llm.model_name}",
            llm=llm,
            input_mappings={"instruction": "prompt"},
        )
        generations.append(task)

    combine_columns = GroupColumns(columns=["generation", "model_name"])

    load_dataset >> random_routing_batch >> generations >> combine_columns


if __name__ == "__main__":
    distiset = pipeline.run(
        use_cache=False,
        parameters={
            "load_dataset": {
                "repo_id": "distilabel-internal-testing/instruction-dataset-mini",
                "num_examples": 3,
                "split": "test",
            },
        }
    )

Expected behaviour
For the serialization to handle this case, and/or for caching to actually be skipped when specified. Perhaps I am missing something in terms of best practices?

Desktop (please complete the following information):

  • Package version: built from source, happening on both main and develop branches
  • Python version: 3.11.9
@Abdelrhman-Wael

Facing the same issue with routing_batch_function.
