

OminiControl Training 🛠️



Setup

  1. Environment
    conda create -n omini python=3.10
    conda activate omini
  2. Requirements
    pip install -r train/requirements.txt
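After activating the environment and installing the requirements, a quick sanity check can confirm that the interpreter matches the `python=3.10` used in the conda command above and that key packages import. This is a minimal stdlib-only sketch. The function name `check_environment` and the default module list (just `torch`) are assumptions, not part of the repository; extend the list to match `train/requirements.txt`.

```python
import importlib.util
import sys
from typing import Iterable, List, Sequence, Tuple


def check_environment(
    required_modules: Iterable[str] = ("torch",),  # assumption: extend from requirements.txt
    expected_python: Tuple[int, int] = (3, 10),  # matches `conda create ... python=3.10`
    python_version: Sequence[int] = sys.version_info[:2],
) -> List[str]:
    """Return human-readable problems; an empty list means the checks passed."""
    problems: List[str] = []
    if tuple(python_version[:2]) != tuple(expected_python):
        problems.append(
            f"Python {expected_python[0]}.{expected_python[1]} expected, "
            f"found {python_version[0]}.{python_version[1]}"
        )
    for name in required_modules:
        # find_spec returns None when a top-level module cannot be found
        if importlib.util.find_spec(name) is None:
            problems.append(f"missing module: {name}")
    return problems


if __name__ == "__main__":
    issues = check_environment()
    print("\n".join(issues) if issues else "environment looks OK")
```

Only top-level module names should be passed, since `find_spec` on a dotted name imports the parent package first.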


Dataset

  1. Download the Subject200K dataset (subject-driven generation):
    bash train/script/data_download/
  2. Download the text-to-image-2M dataset (spatial control task):
    bash train/script/data_download/
    Note: By default, only a few files are downloaded. You can modify the script to download additional data. Remember to update the config file to specify the training data accordingly.


Start training

Config file path: ./train/config

Scripts path: ./train/script

  1. Subject-driven generation
    bash train/script/
  2. Spatial control task
    bash train/script/

Note: Detailed WandB and GPU settings can be found in the script files and the config files.

Other spatial control tasks

This repository supports 7 spatial control tasks:

  1. Canny edge to image (canny)
  2. Image colorization (coloring)
  3. Image deblurring (deblurring)
  4. Depth map to image (depth)
  5. Image to depth map (depth_pred)
  6. Image inpainting (fill)
  7. Super resolution (sr)

To switch between tasks, change the condition_type parameter in the config file config/canny_512.yaml.
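Switching tasks can also be scripted. The sketch below assumes the config is plain YAML readable by PyYAML; since the exact nesting of `condition_type` is not shown here, the helper searches the whole parsed tree and replaces every `condition_type` key it finds. The function names are hypothetical, and the accepted values are the seven task names listed above.

```python
from typing import Any

# The seven condition types listed in this README.
SUPPORTED_CONDITION_TYPES = {
    "canny", "coloring", "deblurring", "depth", "depth_pred", "fill", "sr",
}


def set_condition_type(config: Any, new_type: str) -> int:
    """Replace every `condition_type` value in a parsed config (dicts/lists).

    Returns the number of keys updated, so a config without the key is detectable.
    """
    if new_type not in SUPPORTED_CONDITION_TYPES:
        raise ValueError(f"unsupported condition_type: {new_type!r}")
    return _replace(config, new_type)


def _replace(node: Any, new_type: str) -> int:
    count = 0
    if isinstance(node, dict):
        for key, value in node.items():
            if key == "condition_type":
                node[key] = new_type  # overwriting an existing key is safe while iterating
                count += 1
            else:
                count += _replace(value, new_type)
    elif isinstance(node, list):
        for item in node:
            count += _replace(item, new_type)
    return count


def write_task_config(src_path: str, dst_path: str, new_type: str) -> int:
    """Copy a YAML config with a different task. Requires PyYAML."""
    import yaml  # imported lazily so the helpers above work without PyYAML

    with open(src_path) as f:
        config = yaml.safe_load(f)
    count = set_condition_type(config, new_type)
    if count == 0:
        raise KeyError(f"no condition_type key found in {src_path}")
    with open(dst_path, "w") as f:
        yaml.safe_dump(config, f, sort_keys=False)
    return count
```

For example, assuming the config lives under ./train/config, `write_task_config("train/config/canny_512.yaml", "train/config/depth_512.yaml", "depth")` would write a depth variant (the output file name is hypothetical). Note that `yaml.safe_dump` drops comments, so editing the file by hand keeps them.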

Customize your own task

You can customize your own task by constructing a new dataset and modifying the training code.

  1. Dataset:

    Construct a new dataset with the following format: (src/train/

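    from torch.utils.data import Dataset

    class MyDataset(Dataset):
        def __init__(self, samples):
            self.samples = samples  # hypothetical: whatever your task needs

        def __len__(self):
            return len(self.samples)

        def __getitem__(self, idx):
            # Hypothetical: load the target image, condition image and caption for idx
            image, condition_img, description = self.samples[idx]
            position_delta = [0, 0]  # spatial task; see the note below
            return {
                "image": image,
                "condition": condition_img,
                "condition_type": "your_condition_type",
                "description": description,
                "position_delta": position_delta,
            }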

    Note: For spatial control tasks, set position_delta to [0, 0]. For non-spatial control tasks (such as subject-driven generation), set it to [0, -condition_width // 16].

  2. Condition:

    Add a new condition type to the Condition class. (src/flux/

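    condition_dict = {
        # ... existing condition types ...
        "your_condition_type": your_condition_id_number,  # Add your condition type here
    }
    if condition_type in [
        "your_condition_type",  # Add your condition type here, next to the existing ones
    ]:
        ...  # existing handling for these condition types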
  3. Test:

    Add a new test function for your task. (src/train/

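    from PIL import Image  # at the top of the file

    if self.condition_type == "your_condition_type":
        condition_img = (
            Image.open("path/to/test_image.jpg")  # hypothetical placeholder: use your own test image
            .resize((condition_size, condition_size))
        )
        test_list.append((condition_img, [0, 0], "A beautiful vase on a table."))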
  4. Import the relevant dataset in the training script and update the following section of the file. (src/train/

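     from .data import (
         MyDataset,  # add your dataset class to the existing imports
     )
     # Initialize dataset and dataloader
     if training_config["dataset"]["type"] == "your_condition_type":
         dataset = MyDataset(...)  # hypothetical: pass your dataset's arguments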
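The dataset steps can be tied together in a small framework-free sketch: a helper applying the position_delta rule from the dataset note, a toy in-memory dataset returning the item layout shown in step 1, and a registry that replaces the if-chain from step 4. All names here (`position_delta`, `ToyConditionDataset`, `DATASET_BUILDERS`, `build_dataset`) are hypothetical, not the repository's code. The factor 16 presumably reflects FLUX's latent layout (8× VAE downsampling followed by 2×2 patch packing, so one token spans 16 pixels), meaning the non-spatial delta shifts the condition's token positions by its width in tokens so they do not overlap the target's.

```python
from typing import Any, Callable, Dict, List, Sequence, Tuple

# Assumption: one FLUX latent token covers a 16x16 pixel patch.
PIXELS_PER_TOKEN = 16


def position_delta(spatial: bool, condition_width: int) -> List[int]:
    """position_delta rule from the dataset note."""
    if spatial:
        return [0, 0]  # condition tokens share positions with the target
    # Same expression as the note: (-condition_width) // 16
    return [0, -condition_width // PIXELS_PER_TOKEN]


class ToyConditionDataset:
    """Hypothetical in-memory dataset yielding items in the step 1 format.

    Each sample is an (image, condition_img, description) triple; real code
    would load PIL images from disk instead.
    """

    def __init__(self, samples: Sequence[Tuple[Any, Any, str]],
                 condition_type: str, spatial: bool, condition_width: int = 512):
        self.samples = list(samples)
        self.condition_type = condition_type
        self.delta = position_delta(spatial, condition_width)

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        image, condition_img, description = self.samples[idx]
        return {
            "image": image,
            "condition": condition_img,
            "condition_type": self.condition_type,
            "description": description,
            "position_delta": list(self.delta),  # copy so callers cannot mutate it
        }


# Hypothetical registry: one entry per dataset type instead of an if-chain.
DATASET_BUILDERS: Dict[str, Callable[..., Any]] = {
    "your_condition_type": ToyConditionDataset,
}


def build_dataset(training_config: Dict[str, Any], **kwargs: Any) -> Any:
    dataset_type = training_config["dataset"]["type"]
    if dataset_type not in DATASET_BUILDERS:
        raise ValueError(f"unknown dataset type: {dataset_type!r}")
    return DATASET_BUILDERS[dataset_type](**kwargs)
```

The toy class does not subclass `torch.utils.data.Dataset`, but because it implements `__len__` and `__getitem__` it should still work as a map-style dataset with a PyTorch `DataLoader`. A registry keeps each new task to a one-line addition rather than another branch.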

Hardware requirements

Note: Memory optimization (like dynamic T5 model loading) is pending implementation.


  • Hardware: 2x NVIDIA H100 GPUs
  • Memory: ~80GB GPU memory


  • Hardware: 1x NVIDIA L20 GPU
  • Memory: ~48GB GPU memory