OminiControl Training 🛠️

Preparation

Setup

  1. Environment
    conda create -n omini python=3.10
    conda activate omini
  2. Requirements
    pip install -r train/requirements.txt

Dataset

  1. Download the Subject200K dataset (for subject-driven generation):
    bash train/script/data_download/data_download1.sh
    
  2. Download the text-to-image-2M dataset (for the spatial control task):
    bash train/script/data_download/data_download2.sh
    
    Note: By default, only a few files are downloaded. You can modify data_download2.sh to download additional datasets. Remember to update the config file to specify the training data accordingly.

Training

Start training

Config file path: ./train/config

Scripts path: ./train/script

  1. Subject-driven generation
    bash train/script/train_subject.sh
  2. Spatial control task
    bash train/script/train_canny.sh

Note: Detailed WandB and GPU settings can be found in the script files and the config files.

Other spatial control tasks

This repository supports 7 spatial control tasks:

  1. Canny edge to image (canny)
  2. Image colorization (coloring)
  3. Image deblurring (deblurring)
  4. Depth map to image (depth)
  5. Image to depth map (depth_pred)
  6. Image inpainting (fill)
  7. Super resolution (sr)

You can modify the condition_type parameter in the config file config/canny_512.yaml to switch between different tasks.
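Switching tasks is a one-key change. A minimal sketch of the relevant fragment of config/canny_512.yaml (surrounding keys omitted; this is an assumption about the file's layout, not its verbatim contents):

```yaml
# Pick one of the 7 supported tasks listed above.
condition_type: depth  # canny | coloring | deblurring | depth | depth_pred | fill | sr
```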

Customize your own task

You can customize your own task by constructing a new dataset and modifying the training code.

Instructions
  1. Dataset:

    Construct a new dataset with the following format: (src/train/data.py)

    class MyDataset(Dataset):
        def __init__(self, ...):
            ...
        def __len__(self):
            ...
        def __getitem__(self, idx):
            ...
            return {
                "image": image,
                "condition": condition_img,
                "condition_type": "your_condition_type",
                "description": description,
                "position_delta": position_delta
            }

    Note: For spatial control tasks, set position_delta to [0, 0]. For non-spatial control tasks (e.g. subject-driven generation), set position_delta to [0, -condition_width // 16].
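    The note above can be sketched as a small helper. The function name and the task set are assumptions drawn from this README, not the repository's actual API:

```python
# Hypothetical helper illustrating the position_delta rule above.
# The // 16 reflects FLUX's patch size: 16 pixels per token.
SPATIAL_TASKS = {"canny", "coloring", "deblurring", "depth", "depth_pred", "fill", "sr"}

def get_position_delta(condition_type, condition_width):
    if condition_type in SPATIAL_TASKS:
        # Spatial conditions align with the target image: no shift.
        return [0, 0]
    # Non-spatial conditions (e.g. subject-driven generation) are shifted
    # left by the condition's width in tokens.
    return [0, -condition_width // 16]
```

    For example, a 512-pixel-wide subject condition gets a delta of [0, -32].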

  2. Condition:

    Add a new condition type to the Condition class. (src/flux/condition.py)

    condition_dict = {
        ...
        "your_condition_type": your_condition_id_number, # Add your condition type here
    }
    ...
    if condition_type in [
        ...
        "your_condition_type", # Add your condition type here
    ]:
        ...
  3. Test:

    Add a new test function for your task. (src/train/callbacks.py)

    if self.condition_type == "your_condition_type":
        condition_img = (
            Image.open("images/vase.jpg")
            .resize((condition_size, condition_size))
            .convert("RGB")
        )
        ...
        test_list.append((condition_img, [0, 0], "A beautiful vase on a table."))
  4. Import the relevant dataset in the training script and update the following section. (src/train/train.py)

     from .data import (
         ImageConditionDataset,
         Subject200KDateset,
         MyDataset
     )
     ...
    
     # Initialize dataset and dataloader
     if training_config["dataset"]["type"] == "your_condition_type":
        ...
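The dispatch above can be sketched end-to-end. The class bodies and the config key names here are stand-ins for illustration, not the repository's real implementations:

```python
# Stand-in dataset classes; the real ones live in src/train/data.py.
class ImageConditionDataset:
    pass

class MyDataset:
    pass

def build_dataset(training_config):
    """Pick a dataset class from the config, mirroring train.py's pattern."""
    dataset_type = training_config["dataset"]["type"]
    if dataset_type == "your_condition_type":
        return MyDataset()
    # Fall back to the existing image-condition dataset for other types.
    return ImageConditionDataset()
```

For example, build_dataset({"dataset": {"type": "your_condition_type"}}) returns a MyDataset instance.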

Hardware requirement

Note: Memory optimization (like dynamic T5 model loading) is pending implementation.

  • Hardware: 2x NVIDIA H100 GPUs
  • Memory: ~80GB GPU memory