Simple re-implementation of inference-time scaling for Flux.1-Dev, as introduced in Inference-Time Scaling for Diffusion Models beyond Scaling Denoising Steps by Ma et al. We implement the random search strategy to scale the inference compute budget.
Updates
🔥 16/02/2025: Support for batched image generation has been added in this PR. It reduces the total generation time but consumes more memory.
🔥 15/02/2025: Support for structured generation with Qwen2.5 has been added (using outlines and pydantic) in this PR.
🔥 15/02/2025: Support for loading other pipelines has been added in this PR! The Results section has been updated, too.
Make sure to install the dependencies: pip install -r requirements. The codebase was tested using a single H100 and two H100s (both 80GB variants).
By default, we use Gemini 2.0 Flash as the verifier (you can use Qwen2.5, too). This requires two things:
Now, fire up:
GEMINI_API_KEY=... python main.py --prompt="a tiny astronaut hatching from an egg on the moon" --num_prompts=None
If you want to use prompts from the data-is-better-together/open-image-preferences-v1-binarized dataset, you can just run:
GEMINI_API_KEY=... python main.py
Once execution finishes, you should find a folder named output with the following structure:
output/flux.1-dev/gemini/overall_score/20250215_141308$ tree
.
├── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@…
├── …
└── prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@…
The folder holds 34 files: the 30 images generated across the four search rounds (2 + 4 + 8 + 16 noise samples) plus one JSON file per round.
Each JSON file should look like so:
{
"prompt": "Photo of an athlete cat explaining it\u2019s latest scandal at a press conference to journalists.",
"search_round": 4,
"num_noises": 16,
"best_noise_seed": 1940263961,
"best_score": {
"explanation": "The image excels in accuracy, visual quality, and originality, with minor deductions for thematic resonance. Overall, it's a well-executed and imaginative response to the prompt.",
"score": 9.0
},
"choice_of_metric": "overall_score",
"best_img_path": "output/flux.1-dev/gemini/overall_score/20250216_135414/prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@…"
}
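The per-round JSON files make it easy to compare rounds after a run. Below is a minimal sketch, not part of the repository, that loads every JSON file in a run folder (skipping a run-level config.json, if present), optionally filters by prompt, and picks the round with the highest best_score["score"]. The function names are hypothetical, and the sketch assumes each file follows the structure shown above.

```python
import json
from pathlib import Path


def load_search_rounds(run_dir, prompt=None):
    """Load the per-round search JSON files in run_dir, sorted by search_round.

    Hypothetical helper: assumes every *.json file except config.json
    follows the structure shown above. If prompt is given, only records
    for that prompt are returned.
    """
    records = []
    for path in Path(run_dir).glob("*.json"):
        if path.name == "config.json":
            continue  # run-level settings, not a search-round record
        record = json.loads(path.read_text(encoding="utf-8"))
        if prompt is None or record["prompt"] == prompt:
            records.append(record)
    return sorted(records, key=lambda rec: rec["search_round"])


def best_overall(records):
    """Return the record with the highest best_score["score"].

    On ties, max() keeps the first match, i.e. the earliest round.
    """
    return max(records, key=lambda rec: rec["best_score"]["score"])
```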
To limit the number of prompts, specify --num_prompts. By default, we use 2 prompts. Specify --num_prompts=all to use all of them.
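The documented forms of --num_prompts (an integer, "all", and "None" combined with --prompt, as in the first command above) can be summarized with the hypothetical helper below. It is only a sketch: it assumes an integer N means "the first N dataset prompts", which may differ from how main.py actually selects prompts.

```python
def resolve_prompts(dataset_prompts, num_prompts="2", prompt=None):
    """Hypothetical sketch of how --num_prompts could select prompts.

    - "None": use only the single prompt passed via --prompt
    - "all": use every prompt in the dataset
    - an integer string N: use the first N dataset prompts (an assumption)
    """
    if num_prompts == "None":
        if prompt is None:
            raise ValueError("--num_prompts=None requires --prompt")
        return [prompt]
    if num_prompts == "all":
        return list(dataset_prompts)
    n = int(num_prompts)
    if n < 1:
        raise ValueError("--num_prompts must be a positive integer, 'all', or 'None'")
    return list(dataset_prompts)[:n]
```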
The output directory should also contain a config.json, looking like so:
{
"max_new_tokens": 300,
"use_low_gpu_vram": false,
"choice_of_metric": "overall_score",
"verifier_to_use": "gemini",
"torch_dtype": "bf16",
"height": 1024,
"width": 1024,
"max_sequence_length": 512,
"guidance_scale": 3.5,
"num_inference_steps": 50,
"pipeline_config_path": "configs/flux.1_dev.json",
"search_rounds": 4,
"prompt": "an anime illustration of a wiener schnitzel",
"num_prompts": null
}
Note
The max_new_tokens arg is ignored when using Gemini.
Once the results are generated, process the results by running:
python process_results.py --path=path_to_the_output_dir
This should output a collage of the best images generated in each search round, grouped by prompt.
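process_results.py builds the collage itself. As a rough, dependency-free sketch of the bookkeeping involved (not the script's actual code), the helpers below group the per-round best_img_path values by prompt and compute row-major tile offsets for a grid. Both function names are hypothetical; the offsets could then be passed to an image library, for example as the box argument of Pillow's Image.paste.

```python
from collections import defaultdict


def group_best_images(records):
    """Map each prompt to its per-round best_img_path values, in round order.

    records: dicts shaped like the per-round search JSON files.
    """
    grouped = defaultdict(list)
    for rec in sorted(records, key=lambda r: r["search_round"]):
        grouped[rec["prompt"]].append(rec["best_img_path"])
    return dict(grouped)


def grid_offsets(num_tiles, tile_w, tile_h, cols):
    """Top-left (x, y) pixel offsets for num_tiles tiles laid out row by row."""
    return [((i % cols) * tile_w, (i // cols) * tile_h) for i in range(num_tiles)]
```

For four rounds of 1024x1024 images in a single row (cols=4), the collage canvas would be 4096x1024 pixels.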
By default, --batch_size_for_img_gen is set to 1. To speed up the process (at the expense of more memory), this number can be increased.
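Roughly speaking, with a batch size of B a pool of N noise samples needs ceil(N / B) pipeline calls instead of N, which is where both the speed-up and the extra memory come from. A minimal sketch of the chunking, using hypothetical helper names:

```python
import math


def batched_seeds(seeds, batch_size=1):
    """Split a list of noise seeds into consecutive chunks of at most batch_size.

    Each chunk would be generated in one pipeline call; batch_size=1
    reproduces the default one-image-per-call behaviour.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    return [seeds[i:i + batch_size] for i in range(0, len(seeds), batch_size)]


def num_pipeline_calls(num_noises, batch_size=1):
    """Number of generation calls needed for one search round."""
    return math.ceil(num_noises / batch_size)
```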
Experiment configurations are provided through the --pipeline_config_path arg, which points to a JSON file. The structure of such JSON files should look like so:
{
"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev",
"torch_dtype": "bf16",
"pipeline_call_args": {
"height": 1024,
"width": 1024,
"max_sequence_length": 512,
"guidance_scale": 3.5,
"num_inference_steps": 50
},
"verifier_args": {
"name": "gemini",
"max_new_tokens": 800,
"choice_of_metric": "overall_score"
},
"search_args": {
"search_method": "random",
"search_rounds": 4
}
}
This lets us control the pipeline call arguments, the verifier, and the search process. By default, --pipeline_config_path points to configs/flux.1_dev.json. You can either modify this file or create your own JSON file to experiment with different pipelines. We provide predefined configs for Flux.1-Dev, PixArt-Sigma, SDXL, and SD v1.5 in the configs directory.
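As a sketch of how such a file splits into those three groups, the snippet below parses a config and returns the model ID, the keyword arguments for the pipeline call, and the verifier and search settings. The helper names are hypothetical, and the required-key check is an assumption based on the example config above, not the repository's own validation logic.

```python
import json

# Top-level keys present in the example config above.
REQUIRED_KEYS = (
    "pretrained_model_name_or_path",
    "torch_dtype",
    "pipeline_call_args",
    "verifier_args",
    "search_args",
)


def parse_pipeline_config(text):
    """Parse a pipeline config JSON string and check its top-level keys."""
    config = json.loads(text)
    missing = [key for key in REQUIRED_KEYS if key not in config]
    if missing:
        raise KeyError(f"pipeline config is missing keys: {missing}")
    return config


def split_pipeline_config(config):
    """Return (model_id, call_kwargs, verifier_args, search_args).

    call_kwargs is meant to be unpacked into the pipeline call,
    e.g. pipe(prompt=..., generator=..., **call_kwargs).
    """
    return (
        config["pretrained_model_name_or_path"],
        dict(config["pipeline_call_args"]),
        dict(config["verifier_args"]),
        dict(config["search_args"]),
    )
```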
The above-mentioned pipelines are already supported. To add your own, you need to make modifications to:
By default, we use 4 search_rounds and start with a noise pool size of 2. Each search round scales up the pool size as 2 ** current_search_round (with indexing starting from 1). This is where the "scale" in inference-time scaling comes from. You can increase the compute budget by specifying a larger search_rounds in the config file.
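Concretely, the default 4 rounds evaluate 2 + 4 + 8 + 16 = 30 noise samples in total, and round 6 uses 2 ** 6 = 64 samples, matching num_noises in the sample JSON files. The loop below is a minimal sketch of this schedule, not the repository's implementation: score_fn is a hypothetical stand-in for "generate an image from this noise seed and score it with the verifier", and drawing fresh seeds from Python's random module each round is an assumption.

```python
import random

MAX_SEED = 2**32 - 1  # hypothetical upper bound for noise seeds


def noise_pool_sizes(search_rounds=4):
    """Pool size for each round: 2 ** round, with rounds indexed from 1."""
    return [2**r for r in range(1, search_rounds + 1)]


def random_search(score_fn, search_rounds=4, rng_seed=0):
    """Random search over noise seeds with a pool that doubles every round.

    score_fn(noise_seed) -> float is a hypothetical callback that should
    generate an image from the seed and return the verifier's score.
    Returns one summary dict per round, loosely mirroring the JSON fields.
    """
    rng = random.Random(rng_seed)
    history = []
    for search_round, pool_size in enumerate(noise_pool_sizes(search_rounds), start=1):
        seeds = [rng.randint(0, MAX_SEED) for _ in range(pool_size)]
        scores = [score_fn(seed) for seed in seeds]
        best_idx = max(range(pool_size), key=scores.__getitem__)
        history.append({
            "search_round": search_round,
            "num_noises": pool_size,
            "best_noise_seed": seeds[best_idx],
            "best_score": scores[best_idx],
        })
    return history
```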
For each search round, we save the generated images and serialize the best datapoint (the one with the highest verifier score) to a JSON file.
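A minimal sketch of that serialization, producing the same fields as the sample JSON files in this README. build_round_record and dump_round_record are hypothetical helpers, not the repository's code. Note that json.dumps escapes non-ASCII characters by default (ensure_ascii=True), which is consistent with the prompt appearing as it\u2019s in the samples.

```python
import json


def build_round_record(*, prompt, search_round, num_noises, best_noise_seed,
                       explanation, score, choice_of_metric, best_img_path):
    """Assemble the per-round summary in the field order used by the samples."""
    return {
        "prompt": prompt,
        "search_round": search_round,
        "num_noises": num_noises,
        "best_noise_seed": best_noise_seed,
        "best_score": {"explanation": explanation, "score": score},
        "choice_of_metric": choice_of_metric,
        "best_img_path": best_img_path,
    }


def dump_round_record(record, path):
    """Write the record as indented JSON (non-ASCII escaped, json's default)."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(record, f, indent=4)
```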
For other supported CLI args, run python main.py -h.
If you don't want to use Gemini, you can use Qwen2.5 instead. Simply specify "name": "qwen" under "verifier_args" in the config. Below is a complete command that uses SDXL-base:
python main.py \
--pipeline_config_path="configs/sdxl.json" \
--prompt="Photo of an athlete cat explaining it’s latest scandal at a press conference to journalists."
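Instead of editing a shared config in place, you could derive a Qwen variant programmatically. The sketch below deep-copies a config dict and swaps verifier_args["name"]; with_verifier, write_config, and the output filename in the usage note are hypothetical.

```python
import copy
import json


def with_verifier(config, name="qwen"):
    """Return a deep copy of a pipeline config with verifier_args["name"] replaced."""
    new_config = copy.deepcopy(config)
    new_config.setdefault("verifier_args", {})["name"] = name
    return new_config


def write_config(config, path):
    """Write a config dict to a JSON file that --pipeline_config_path can point to."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=4)
```

For example, you could load configs/sdxl.json with json.load, pass it through with_verifier, save the result as configs/sdxl_qwen.json, and then run main.py with --pipeline_config_path="configs/sdxl_qwen.json".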
Sample search JSON
{
"prompt": "Photo of an athlete cat explaining it\u2019s latest scandal at a press conference to journalists.",
"search_round": 6,
"num_noises": 64,
"best_noise_seed": 1937268448,
"best_score": {
"explanation": "Overall, the image demonstrates a high level of accuracy, creativity, and theme consistency while maintaining a high visual quality and coherence within the depicted scenario. The humor and surprise value are significant, contributing to above-average scoring.",
"score": 9.0
},
"choice_of_metric": "overall_score",
"best_img_path": "output/sdxl-base/qwen/overall_score/20250216_141140/prompt@Photo_of_an_athlete_cat_explaining_it_s_latest_scandal_at_a_press_conference_to_journ_hash@b9094b65_i@…"
}
Results
Result for the prompt: Photo of an athlete cat explaining it’s latest scandal at a press conference to journalists.
Important
This setup was tested on 2 H100s. If you want to run it on a single GPU, specify --use_low_gpu_vram.
You can also bring in your own verifier by implementing a Verifier class that follows the structure of either GeminiVerifier or QwenVerifier. You will then have to make adjustments to the following places:
By default, we use "overall_score" as the metric to select the best samples in each search round. You can change it by specifying --choice_of_metric. Supported values are:
- "accuracy_to_prompt"
- "creativity_and_originality"
- "visual_quality_and_realism"
- "consistency_and_cohesion"
- "emotional_or_thematic_resonance"
- "overall_score"
If you're experimenting with a new verifier, you can relax these choices.
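To illustrate what --choice_of_metric changes, here is a minimal sketch that picks the best noise seed from verifier outputs by a chosen metric. It assumes, hypothetically, that a verifier returns for each image a dict mapping every metric name to {"explanation": ..., "score": ...}, matching the shape of best_score in the sample JSON files; the real verifiers' return format may differ.

```python
SUPPORTED_METRICS = (
    "accuracy_to_prompt",
    "creativity_and_originality",
    "visual_quality_and_realism",
    "consistency_and_cohesion",
    "emotional_or_thematic_resonance",
    "overall_score",
)


def select_best(outputs, choice_of_metric="overall_score"):
    """Return (noise_seed, metric_result) with the highest score for choice_of_metric.

    outputs: list of (noise_seed, verifier_output) pairs, where verifier_output
    maps metric name -> {"explanation": str, "score": float} (assumed shape).
    """
    if choice_of_metric not in SUPPORTED_METRICS:
        raise ValueError(f"unsupported metric: {choice_of_metric}")
    seed, result = max(outputs, key=lambda item: item[1][choice_of_metric]["score"])
    return seed, result[choice_of_metric]
```

Relaxing the choices for a new verifier would then amount to extending SUPPORTED_METRICS with that verifier's metric names.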
The verifier prompt used during grading/verification is specified in this file. It is a slightly modified version of the prompt given in Figure 16 of the paper (Inference-Time Scaling for Diffusion Models beyond Scaling Denoising Steps). You are welcome to experiment with a different prompt.
Both searches were performed with "overall_score" as the metric. Below is an example comparing the outputs of two metrics, "overall_score" and "emotional_or_thematic_resonance", for the prompt "a tiny astronaut hatching from an egg on the moon":
PixArt-Sigma
Results for the following prompts:
- A person playing saxophone.
- Photo of an athlete cat explaining it’s latest scandal at a press conference to journalists.