tt-scale-flux

A simple re-implementation of inference-time scaling for Flux.1-Dev, as introduced in Inference-Time Scaling for Diffusion Models beyond Scaling Denoising Steps by Ma et al. We implement the random search strategy to scale the inference compute budget.

Photo of an athlete cat explaining it’s latest scandal at a press conference to journalists.

Updates

🔥 16/02/2025: Support for batched image generation has been added in this PR. It reduces the total run time but consumes more memory.

🔥 15/02/2025: Support for structured generation with Qwen2.5 has been added (using outlines and pydantic) in this PR.

🔥 15/02/2025: Support to load other pipelines has been added in this PR! Result section has been updated, too.

Getting started

Make sure to install the dependencies: pip install -r requirements. The codebase was tested using a single H100 and two H100s (both 80GB variants).

By default, we use Gemini 2.0 Flash as the verifier (you can use Qwen2.5, too). This requires two things:

  • GEMINI_API_KEY (obtain it from here).
  • google-genai Python library.

Now, fire up:

GEMINI_API_KEY=... python main.py --prompt="a tiny astronaut hatching from an egg on the moon" --num_prompts=None

If you want to use prompts from the data-is-better-together/open-image-preferences-v1-binarized dataset, you can just run:

GEMINI_API_KEY=... python main.py

After this is done executing, you should expect a folder named output with the following structure:

output/flux.1-dev/gemini/overall_score/20250215_141308$ tree 
.
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@[email protected]
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@[email protected]
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@[email protected]
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@[email protected]
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@[email protected]
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@[email protected]
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@[email protected]
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@[email protected]
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@[email protected]
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@[email protected]
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@[email protected]
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@[email protected]
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@[email protected]
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@[email protected]
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@[email protected]
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@[email protected]
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@[email protected]
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@[email protected]
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@[email protected]
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@[email protected]
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@[email protected]
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@[email protected]
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@[email protected]
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@[email protected]
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@[email protected]
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@[email protected]
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@[email protected]
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@[email protected]
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@[email protected]
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@[email protected]
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@[email protected]
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@[email protected]
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@[email protected]
└── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@[email protected]

Each JSON file should look like so:

{
    "prompt": "Photo of an athlete cat explaining it\u2019s latest scandal at a press conference to journalists.",
    "search_round": 4,
    "num_noises": 16,
    "best_noise_seed": 1940263961,
    "best_score": {
        "explanation": "The image excels in accuracy, visual quality, and originality, with minor deductions for thematic resonance. Overall, it's a well-executed and imaginative response to the prompt.",
        "score": 9.0
    },
    "choice_of_metric": "overall_score",
    "best_img_path": "output/flux.1-dev/gemini/overall_score/20250216_135414/prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@[email protected]"
}
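As a rough illustration of how these per-round records could be consumed, the sketch below walks an output directory, loads every JSON record except config.json, and lists the best seed and score per search round. The helper names `load_round_records` and `summarize` are hypothetical and not part of the repository; only the JSON keys come from the sample above.

```python
import json
from pathlib import Path


def load_round_records(output_dir):
    """Read every per-round JSON record found under output_dir."""
    records = []
    for path in sorted(Path(output_dir).rglob("*.json")):
        # config.json stores the run configuration, not a search result.
        if path.name == "config.json":
            continue
        with open(path) as f:
            records.append(json.load(f))
    return records


def summarize(records):
    """Return (search_round, num_noises, best_noise_seed, score) tuples,
    ordered by prompt and then by search round."""
    ordered = sorted(records, key=lambda r: (r["prompt"], r["search_round"]))
    return [
        (
            r["search_round"],
            r["num_noises"],
            r["best_noise_seed"],
            r["best_score"]["score"],
        )
        for r in ordered
    ]
```

Calling `summarize(load_round_records("output"))` would then give one row per search round, which makes it easy to see whether the best score improves as the noise pool grows.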

To limit the number of prompts, specify --num_prompts. By default, we use 2 prompts. Specify "--num_prompts=all" to use all.

The output directory should also contain a config.json, looking like so:

{
  "max_new_tokens": 300,
  "use_low_gpu_vram": false,
  "choice_of_metric": "overall_score",
  "verifier_to_use": "gemini",
  "torch_dtype": "bf16",
  "height": 1024,
  "width": 1024,
  "max_sequence_length": 512,
  "guidance_scale": 3.5,
  "num_inference_steps": 50,
  "pipeline_config_path": "configs/flux.1_dev.json",
  "search_rounds": 4,
  "prompt": "an anime illustration of a wiener schnitzel",
  "num_prompts": null
}

Note

max_new_tokens arg is ignored when using Gemini.

Once the results are generated, process the results by running:

python process_results.py --path=path_to_the_output_dir

This should output a collage of the best images generated in each search round, grouped by the same prompt.
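The grouping step behind such a collage can be sketched as follows. This is not the code in process_results.py; it is a minimal illustration, assuming records shaped like the per-round JSON files, of how the best image paths might be ordered by search round for each prompt. The function name `group_best_images` is hypothetical.

```python
from collections import defaultdict


def group_best_images(records):
    """Map each prompt to its best image paths, ordered by search round."""
    grouped = defaultdict(list)
    for rec in records:
        grouped[rec["prompt"]].append(rec)
    return {
        prompt: [
            r["best_img_path"]
            for r in sorted(recs, key=lambda r: r["search_round"])
        ]
        for prompt, recs in grouped.items()
    }
```

Each list in the result corresponds to one row of a collage: the best image of round 1 first, then round 2, and so on.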

By default, the --batch_size_for_img_gen is set to 1. To speed up the process (at the expense of more memory), this number can be increased.
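To make the trade-off concrete, here is a minimal sketch of how a pool of noise seeds could be split into batches. It is an assumption about the general idea, not the repository's implementation, and `batched` is a hypothetical helper: larger batches mean fewer pipeline calls but more images held in memory at once.

```python
def batched(items, batch_size):
    """Yield consecutive slices of items with at most batch_size elements."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


# With 16 noises, a batch size of 1 needs 16 calls; a batch size of 4 needs 4.
seeds = list(range(16))
calls_with_bs1 = len(list(batched(seeds, 1)))
calls_with_bs4 = len(list(batched(seeds, 4)))
```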

Controlling experiment configurations

Experiment configurations are provided through the --pipeline_config_path arg which points to a JSON file. The structure of such JSON files should look like so:

{
    "pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev",
    "torch_dtype": "bf16",
    "pipeline_call_args": {
        "height": 1024,
        "width": 1024,
        "max_sequence_length": 512,
        "guidance_scale": 3.5,
        "num_inference_steps": 50
    },
    "verifier_args": {
        "name": "gemini", 
        "max_new_tokens": 800,
        "choice_of_metric": "overall_score"
    },
    "search_args": {
        "search_method": "random",
        "search_rounds": 4
    }
}

This lets us control the pipeline call arguments, the verifier, and the search process.
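For illustration, a variant of this config can be derived programmatically. The sketch below copies the structure shown above and overrides the search rounds and the verifier name. The helper `make_config` is hypothetical, and whether every combination of values is accepted by main.py is an assumption you should verify.

```python
import copy
import json

# Same structure as the example config shown above.
BASE_CONFIG = {
    "pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev",
    "torch_dtype": "bf16",
    "pipeline_call_args": {
        "height": 1024,
        "width": 1024,
        "max_sequence_length": 512,
        "guidance_scale": 3.5,
        "num_inference_steps": 50,
    },
    "verifier_args": {
        "name": "gemini",
        "max_new_tokens": 800,
        "choice_of_metric": "overall_score",
    },
    "search_args": {"search_method": "random", "search_rounds": 4},
}


def make_config(base, search_rounds=None, verifier_name=None):
    """Return a modified deep copy of base; the original is left untouched."""
    cfg = copy.deepcopy(base)
    if search_rounds is not None:
        cfg["search_args"]["search_rounds"] = search_rounds
    if verifier_name is not None:
        cfg["verifier_args"]["name"] = verifier_name
    return cfg


custom = make_config(BASE_CONFIG, search_rounds=6, verifier_name="qwen")
custom_json = json.dumps(custom, indent=4)  # text to save as your own config file
```

Saving `custom_json` to a file and passing its path via --pipeline_config_path keeps the default config intact while you experiment.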

Controlling the pipeline checkpoint and __call__() args

This is controlled via the --pipeline_config_path CLI arg. By default, it uses configs/flux.1_dev.json. You can either modify this one or create your own JSON file to experiment with different pipelines. We provide some predefined configs for Flux.1-Dev, PixArt-Sigma, SDXL, and SD v1.5 in the configs directory.

The above-mentioned pipelines are already supported. To add your own, you need to make modifications to:

Controlling the "scale"

By default, we use 4 search_rounds. Each search round uses a noise pool of size 2 ** current_search_round (with indexing starting from 1), so the pool starts at 2 and doubles every round. This is where the "scale" in inference-time scaling comes from. You can increase the compute budget by specifying a larger search_rounds in the config file.
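The arithmetic of this schedule is easy to check with a few lines of Python. The helper name `noise_pool_sizes` is hypothetical, and the total assumes each round generates a fresh image for every noise in its pool.

```python
def noise_pool_sizes(search_rounds):
    """Pool size for each round, with rounds indexed from 1."""
    return [2 ** r for r in range(1, search_rounds + 1)]


sizes = noise_pool_sizes(4)   # [2, 4, 8, 16]
total_images = sum(sizes)     # 30, i.e. 2 ** (4 + 1) - 2
```

In general, R rounds evaluate 2^(R+1) - 2 noises under this assumption, so each extra round roughly doubles the compute budget.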

For each search round, we serialize the images and best datapoint (characterized by the best eval score) in a JSON file.
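The overall loop can be sketched as below. This is a simplified, self-contained illustration of random search over noise seeds, not the repository's main.py: `fake_generate` and `fake_score` are stand-ins for the diffusion pipeline and the verifier, and here the best score is stored as a plain float rather than a dict with an explanation.

```python
import random


def fake_generate(prompt, seed):
    """Stand-in for the image pipeline: returns a string instead of an image."""
    return f"{prompt}|{seed}"


def fake_score(prompt, image):
    """Stand-in for the verifier: a deterministic pseudo-score in [0, 10)."""
    return random.Random(image).random() * 10


def random_search(prompt, generate, score, search_rounds=4, seed=0):
    """Sample 2 ** round noise seeds per round and keep the best-scoring one."""
    rng = random.Random(seed)
    results = []
    for search_round in range(1, search_rounds + 1):
        num_noises = 2 ** search_round
        seeds = [rng.randrange(2 ** 32) for _ in range(num_noises)]
        # Score one candidate image per seed, then take the maximum.
        scored = [(score(prompt, generate(prompt, s)), s) for s in seeds]
        best_score, best_seed = max(scored)
        results.append(
            {
                "prompt": prompt,
                "search_round": search_round,
                "num_noises": num_noises,
                "best_noise_seed": best_seed,
                "best_score": best_score,
            }
        )
    return results


rounds = random_search("a tiny astronaut", fake_generate, fake_score)
```

Because each round draws a larger pool, later rounds have more chances to find a noise that the verifier scores highly, at the cost of proportionally more generation and verification calls.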

For other supported CLI args, run python main.py -h.

Controlling the verifier

If you don't want to use Gemini, you can use Qwen2.5 instead. Simply set "name": "qwen" under "verifier_args" in the config. Below is a complete command that uses SDXL-base:

python main.py \
  --pipeline_config_path="configs/sdxl.json" \
  --prompt="Photo of an athlete cat explaining it’s latest scandal at a press conference to journalists."
Sample search JSON
{
    "prompt": "Photo of an athlete cat explaining it\u2019s latest scandal at a press conference to journalists.",
    "search_round": 6,
    "num_noises": 64,
    "best_noise_seed": 1937268448,
    "best_score": {
        "explanation": "Overall, the image demonstrates a high level of accuracy, creativity, and theme consistency while maintaining a high visual quality and coherence within the depicted scenario. The humor and surprise value are significant, contributing to above-average scoring.",
        "score": 9.0
    },
    "choice_of_metric": "overall_score",
    "best_img_path": "output/sdxl-base/qwen/overall_score/20250216_141140/prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@[email protected]"
}
  
Results
[Image: scandal_cat] Photo of an athlete cat explaining it’s latest scandal at a press conference to journalists.
  

Important

This setup was tested on 2 H100s. If you want to do this on a single GPU, specify --use_low_gpu_vram.

You can also bring in your own verifier by implementing a so-called Verifier class following the structure of either of GeminiVerifier or QwenVerifier. You will then have to make adjustments to the following places:

By default, we use "overall_score" as the metric to obtain the best samples in each search round. You can change it by specifying --choice_of_metric. Supported values are:

  • "accuracy_to_prompt"
  • "creativity_and_originality"
  • "visual_quality_and_realism"
  • "consistency_and_cohesion"
  • "emotional_or_thematic_resonance"
  • "overall_score"

If you're experimenting with a new verifier, you can relax these choices.
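As a hedged sketch of how a metric choice might drive selection, the snippet below picks the best candidate from a list of verifier outputs. The output shape (one dict per image, mapping each metric name to a dict with "score" and "explanation") is an assumption inferred from the "best_score" field of the sample JSON, and `pick_best` is a hypothetical helper rather than a function from the repository.

```python
SUPPORTED_METRICS = [
    "accuracy_to_prompt",
    "creativity_and_originality",
    "visual_quality_and_realism",
    "consistency_and_cohesion",
    "emotional_or_thematic_resonance",
    "overall_score",
]


def pick_best(outputs, choice_of_metric="overall_score"):
    """Return the verifier output with the highest score for the chosen metric."""
    if choice_of_metric not in SUPPORTED_METRICS:
        raise ValueError(f"Unsupported metric: {choice_of_metric}")
    return max(outputs, key=lambda o: o[choice_of_metric]["score"])


# Two toy verifier outputs (only two metrics filled in, for brevity).
outputs = [
    {
        "overall_score": {"score": 8.0, "explanation": "solid"},
        "emotional_or_thematic_resonance": {"score": 9.5, "explanation": "moving"},
    },
    {
        "overall_score": {"score": 9.0, "explanation": "excellent"},
        "emotional_or_thematic_resonance": {"score": 7.0, "explanation": "flat"},
    },
]
best_overall = pick_best(outputs, "overall_score")
best_resonance = pick_best(outputs, "emotional_or_thematic_resonance")
```

The two calls select different candidates, which shows why the choice of metric can change which noise is considered best in a round.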

The verifier prompt that is used during grading/verification is specified in this file. The prompt is a slightly modified version of the one shown in Figure 16 of the paper (Inference-Time Scaling for Diffusion Models beyond Scaling Denoising Steps). You are welcome to experiment with a different prompt.

More results

[Image: Manga] a bustling manga street, devoid of vehicles, detailed with vibrant colors and dynamic line work, characters in the background adding life and movement, under a soft golden hour light, with rich textures and a lively atmosphere, high resolution, sharp focus

[Image: Alice] Alice in a vibrant, dreamlike digital painting inside the Nemo Nautilus submarine.

[Image: wiener_schnitzel] an anime illustration of a wiener schnitzel
  

Both searches were performed with "overall_score" as the metric. Below is an example comparing the outputs of two metrics, "overall_score" and "emotional_or_thematic_resonance", for the prompt "a tiny astronaut hatching from an egg on the moon":

  • "overall_score": [Image: overall]
  • "emotional_or_thematic_resonance": [Image: Alicet]

Results from other models

PixArt-Sigma
[Image: saxophone] A person playing saxophone.

[Image: scandal_cat] Photo of an athlete cat explaining it’s latest scandal at a press conference to journalists.

SD v1.5
[Image: saxophone] a photo of an astronaut riding a horse on mars
  
SDXL-base
[Image: scandal_cat] Photo of an athlete cat explaining it’s latest scandal at a press conference to journalists.
  

Acknowledgements

  • Thanks to Willis Ma for all the guidance and pair-coding.
  • Thanks to Hugging Face for supporting the compute.
  • Thanks to Google for providing Gemini credits.
  • Thanks a bunch to amitness for this PR.
