- Environment
```bash
conda create -n omini python=3.10
conda activate omini
```
- Requirements
```bash
pip install -r train/requirements.txt
```
- Download the Subject200K dataset (subject-driven generation)
```bash
bash train/script/data_download/data_download1.sh
```
- Download the text-to-image-2M dataset (spatial control task)
```bash
bash train/script/data_download/data_download2.sh
```
Note: By default, only a few files are downloaded. You can modify `data_download2.sh` to download additional datasets. Remember to update the config file to specify the training data accordingly.
- Config file path: `./train/config`
- Scripts path: `./train/script`
- Subject-driven generation
```bash
bash train/script/train_subject.sh
```
- Spatial control task
```bash
bash train/script/train_canny.sh
```
Note: Detailed WandB and GPU settings can be found in the script files and the config files.
This repository supports 7 spatial control tasks:
- Canny edge to image (`canny`)
- Image colorization (`coloring`)
- Image deblurring (`deblurring`)
- Depth map to image (`depth`)
- Image to depth map (`depth_pred`)
- Image inpainting (`fill`)
- Super resolution (`sr`)
You can modify the `condition_type` parameter in the config file `config/canny_512.yaml` to switch between different tasks.
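Editing the YAML by hand is the simplest route. If you prefer to switch tasks from a script, the sketch below assumes the loaded config is a plain nested dict/list structure and that every `condition_type` key in it should take the new value. `set_condition_type` and `SUPPORTED_CONDITION_TYPES` are hypothetical names, not part of this repository; the set just mirrors the task list above.

```python
# Hypothetical: mirrors the spatial control tasks listed in this README.
SUPPORTED_CONDITION_TYPES = {
    "canny", "coloring", "deblurring", "depth", "depth_pred", "fill", "sr",
}


def set_condition_type(config: dict, condition_type: str) -> int:
    """Overwrite every `condition_type` key found anywhere in a nested config.

    Returns the number of keys updated, so a caller can detect a config
    that contains no `condition_type` entry at all.
    """
    if condition_type not in SUPPORTED_CONDITION_TYPES:
        raise ValueError(f"unsupported condition_type: {condition_type!r}")
    updated = 0
    stack = [config]  # iterative depth-first walk over dicts and lists
    while stack:
        node = stack.pop()
        if isinstance(node, dict):
            for key, value in list(node.items()):
                if key == "condition_type":
                    node[key] = condition_type
                    updated += 1
                else:
                    stack.append(value)
        elif isinstance(node, list):
            stack.extend(node)
    return updated
```

With PyYAML, for example, load the file with `yaml.safe_load`, call `set_condition_type(config, "depth")`, and write the result to a new file with `yaml.safe_dump(config, f, sort_keys=False)`. Round-tripping through PyYAML drops comments, so keep the original file untouched.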
You can customize your own task by constructing a new dataset and modifying the training code.
Instructions

- Dataset: Construct a new dataset with the following format (`src/train/data.py`):
```python
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self, *args, **kwargs):
        ...  # load or index your data here

    def __len__(self):
        ...

    def __getitem__(self, idx):
        ...
        return {
            "image": image,
            "condition": condition_img,
            "condition_type": "your_condition_type",
            "description": description,
            "position_delta": position_delta,
        }
```
Note: For spatial control tasks, set `position_delta` to `[0, 0]`. For non-spatial control tasks, set `position_delta` to `[0, -condition_width // 16]`.
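The rule in the note can be written as a small helper. This is a minimal sketch: `position_delta_for`, `build_sample`, and `SPATIAL_CONDITION_TYPES` are hypothetical names, the spatial set simply mirrors the task list in this README, and reading the `// 16` factor as 16 pixels per token (FLUX's 8x VAE downsampling followed by 2x2 patch packing) is an assumption.

```python
# Hypothetical: the spatial tasks listed in this README; any other
# condition type is treated as non-spatial.
SPATIAL_CONDITION_TYPES = {
    "canny", "coloring", "deblurring", "depth", "depth_pred", "fill", "sr",
}


def position_delta_for(condition_type: str, condition_width: int) -> list:
    """Return the `position_delta` value described in the note above."""
    if condition_type in SPATIAL_CONDITION_TYPES:
        # Spatial tasks: condition and target are pixel-aligned, no offset.
        return [0, 0]
    # Non-spatial tasks: offset by the condition width in tokens, assuming
    # 16 pixels per token. Python parses `-condition_width // 16` as
    # `(-condition_width) // 16`, exactly as written in the note.
    return [0, -condition_width // 16]


def build_sample(image, condition_img, condition_type, description, condition_width):
    """Assemble one item in the dict format `MyDataset.__getitem__` returns."""
    return {
        "image": image,
        "condition": condition_img,
        "condition_type": condition_type,
        "description": description,
        "position_delta": position_delta_for(condition_type, condition_width),
    }
```

For a 512-pixel-wide condition image in a non-spatial task, `position_delta_for("your_condition_type", 512)` gives `[0, -32]`.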
- Condition: Add a new condition type in the `Condition` class (`src/flux/condition.py`):
```python
condition_dict = {
    ...
    "your_condition_type": your_condition_id_number,  # Add your condition type here
}
...
if condition_type in [
    ...
    "your_condition_type",  # Add your condition type here
]:
    ...
```
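The snippet leaves `your_condition_id_number` to you. If IDs are meant to tell condition types apart, a new one should not collide with an existing entry. A minimal sketch, assuming the IDs are plain integers that must stay unique; `register_condition_type` is a hypothetical helper, not part of `src/flux/condition.py`.

```python
def register_condition_type(condition_dict: dict, name: str) -> int:
    """Add `name` to a condition-type -> ID mapping with a fresh integer ID.

    Hypothetical helper: assumes IDs are integers that must be unique.
    Returns the existing ID if `name` is already registered.
    """
    if name in condition_dict:
        return condition_dict[name]
    # One past the largest ID in use cannot collide, even if IDs have gaps.
    new_id = max(condition_dict.values(), default=-1) + 1
    condition_dict[name] = new_id
    return new_id
```

Picking one past the current maximum, rather than filling gaps, keeps previously assigned IDs stable when types are removed.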
- Test: Add a new test function for your task (`src/train/callbacks.py`):
```python
# Requires `from PIL import Image` at module level.
if self.condition_type == "your_condition_type":
    condition_img = (
        Image.open("images/vase.jpg")
        .resize((condition_size, condition_size))
        .convert("RGB")
    )
    ...
    test_list.append((condition_img, [0, 0], "A beautiful vase on a table."))
```
- Import the relevant dataset in the training script by updating the following section (`src/train/train.py`):
```python
from .data import (
    ImageConditionDataset,
    Subject200KDateset,
    MyDataset,
)
...
# Initialize dataset and dataloader
if training_config["dataset"]["type"] == "your_condition_type":
    ...
```
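As more tasks are added, the `if training_config["dataset"]["type"] == ...` chain can be replaced by a lookup table. This is a sketch of that alternative, not how `train.py` is written: `DATASET_BUILDERS`, `register_dataset`, and `build_dataset` are hypothetical names, and the builder below returns a list as a stand-in for `MyDataset(...)`.

```python
from typing import Callable, Dict

# Hypothetical registry mapping the config's `dataset.type` string to a
# callable that receives the `dataset` section and returns a dataset.
DATASET_BUILDERS: Dict[str, Callable[[dict], object]] = {}


def register_dataset(dataset_type: str):
    """Decorator that records a dataset builder under `dataset_type`."""
    def decorator(builder):
        if dataset_type in DATASET_BUILDERS:
            raise KeyError(f"dataset type already registered: {dataset_type!r}")
        DATASET_BUILDERS[dataset_type] = builder
        return builder
    return decorator


def build_dataset(training_config: dict):
    """Build the dataset named by `training_config["dataset"]["type"]`."""
    dataset_cfg = training_config["dataset"]
    dataset_type = dataset_cfg["type"]
    if dataset_type not in DATASET_BUILDERS:
        raise ValueError(f"unknown dataset type: {dataset_type!r}")
    return DATASET_BUILDERS[dataset_type](dataset_cfg)


@register_dataset("your_condition_type")
def _build_my_dataset(dataset_cfg: dict):
    # Stand-in: replace with e.g. `MyDataset(...)` built from `dataset_cfg`.
    return ["sample"] * dataset_cfg.get("size", 1)
```

The trade-off is indirection: a new task only needs a decorated builder next to its dataset class, but the selection logic is no longer visible in one place in the training script.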
Note: Memory optimization (like dynamic T5 model loading) is pending implementation.
- Hardware: 2x NVIDIA H100 GPUs
- Memory: ~80GB GPU memory
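As a rough sanity check on the memory figure, model weights alone take parameters × bytes per parameter. The sketch below assumes bf16 weights and a transformer of roughly 12B parameters (the reported size of FLUX.1). Activations, gradients and optimizer state for the trained parameters, and text encoders such as T5 all add to this, which is presumably why dynamic T5 loading is listed as a pending memory optimization. `weight_memory_gb` is a hypothetical helper.

```python
# Bytes needed to store one parameter in each dtype.
BYTES_PER_PARAM = {"fp32": 4, "bf16": 2, "fp16": 2}


def weight_memory_gb(num_params: float, dtype: str = "bf16") -> float:
    """Memory for model weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9
```

For example, `weight_memory_gb(12e9)` returns `24.0`, so about 24 GB of the budget goes to the transformer weights before anything else is loaded.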