Adds TrainDatasetAddMultiReso #66

Open · wants to merge 1 commit into base: main
117 changes: 84 additions & 33 deletions nodes.py
@@ -1,3 +1,4 @@
import copy
import os
import torch
from torchvision import transforms
@@ -144,17 +145,22 @@ def create_config(self, dataset_path, class_tokens, num_repeats):
}

return reg_subset,

class TrainDatasetAdd:

class TrainDatasetAddBase:
MIN_WIDTH = 64
MIN_HEIGHT = 64
DEF_WIDTH = 1024
DEF_HEIGHT = 1024

def __init__(self):
self.previous_dataset_signature = None
self.previous_dataset_signatures = set()

@classmethod
def INPUT_TYPES(s):
def base_input_types(s):
return {"required": {
"dataset_config": ("JSON",),
"width": ("INT",{"min": 64, "default": 1024, "tooltip": "base resolution width"}),
"height": ("INT",{"min": 64, "default": 1024, "tooltip": "base resolution height"}),
"width": None, # placeholder to preserve order
"height": None, # for backward compatibility
"batch_size": ("INT",{"min": 1, "default": 2, "tooltip": "Higher batch size uses more memory and generalizes the training more"}),
"dataset_path": ("STRING",{"multiline": True, "default": "", "tooltip": "path to dataset, root is the 'ComfyUI' folder, with windows portable 'ComfyUI_windows_portable'"}),
"class_tokens": ("STRING",{"multiline": True, "default": "", "tooltip": "aka trigger word, if specified, will be added to the start of each caption, if no captions exist, will be used on it's own"}),
@@ -174,56 +180,99 @@ def INPUT_TYPES(s):
FUNCTION = "create_config"
CATEGORY = "FluxTrainer"

def create_config(self, dataset_config, dataset_path, class_tokens, width, height, batch_size, num_repeats, enable_bucket,
bucket_no_upscale, min_bucket_reso, max_bucket_reso, regularization=None):

def add_dataset_to_config(self, dataset_signatures_to_drop, **kwargs):
new_dataset = {
"resolution": (width, height),
"batch_size": batch_size,
"enable_bucket": enable_bucket,
"bucket_no_upscale": bucket_no_upscale,
"min_bucket_reso": min_bucket_reso,
"max_bucket_reso": max_bucket_reso,
"subsets": [
{
"image_dir": dataset_path,
"class_tokens": class_tokens,
"num_repeats": num_repeats
}
]
"resolution": (kwargs["width"], kwargs["height"]),
"batch_size": kwargs["batch_size"],
"enable_bucket": kwargs["enable_bucket"],
"bucket_no_upscale": kwargs["bucket_no_upscale"],
"min_bucket_reso": kwargs["min_bucket_reso"],
"max_bucket_reso": kwargs["max_bucket_reso"],
"subsets": [{
"image_dir": kwargs["dataset_path"],
"class_tokens": kwargs["class_tokens"],
"num_repeats": kwargs["num_repeats"],
}],
}
if regularization is not None:
new_dataset["subsets"].append(regularization)
if kwargs.get("regularization") is not None:
new_dataset["subsets"].append(kwargs["regularization"])

# Generate a signature for the new dataset
new_dataset_signature = self.generate_signature(new_dataset)

# Load the existing datasets
existing_datasets = json.loads(dataset_config["datasets"])
existing_datasets = json.loads(kwargs["dataset_config"]["datasets"])

# Remove the previously added dataset if it exists
if self.previous_dataset_signature:
if len(dataset_signatures_to_drop) > 0:
existing_datasets["datasets"] = [
ds for ds in existing_datasets["datasets"]
if self.generate_signature(ds) != self.previous_dataset_signature
if self.generate_signature(ds) not in dataset_signatures_to_drop
]

# Add the new dataset
existing_datasets["datasets"].append(new_dataset)

# Store the new dataset signature for future runs
self.previous_dataset_signature = new_dataset_signature
self.previous_dataset_signatures.add(new_dataset_signature)

# Convert back to JSON and update dataset_config
updated_dataset_json = json.dumps(existing_datasets, indent=2)
dataset_config["datasets"] = updated_dataset_json
kwargs["dataset_config"]["datasets"] = updated_dataset_json

return dataset_config,
return kwargs["dataset_config"],

def generate_signature(self, dataset):
# Create a unique signature for the dataset based on its attributes
return json.dumps(dataset, sort_keys=True)

class TrainDatasetAdd(TrainDatasetAddBase):
@classmethod
def INPUT_TYPES(s):
d = copy.deepcopy(s.base_input_types())
d["required"].update({
"width": ("INT",{"min": s.MIN_WIDTH, "default": s.DEF_WIDTH, "tooltip": "base resolution width"}),
"height": ("INT",{"min": s.MIN_HEIGHT, "default": s.DEF_HEIGHT, "tooltip": "base resolution height"}),
})
return d

def create_config(self, **kwargs):
return self.add_dataset_to_config(self.previous_dataset_signatures, **kwargs)

class TrainDatasetAddMultiReso(TrainDatasetAddBase):
@classmethod
def INPUT_TYPES(s):
d = copy.deepcopy(s.base_input_types())
d["required"].update({
"width": ("STRING",{"multiline": False, "default": str(s.DEF_WIDTH), "tooltip": "comma-separated list of base resolution widths"}),
"height": ("STRING",{"multiline": False, "default": str(s.DEF_HEIGHT), "tooltip": "comma-separated list of base resolution heights"}),
})
return d

def create_config(self, **kwargs):
def parse_values(text, name, min_value):
try:
values = [int(v) for v in text.split(",")]
assert len(values) > 0
except Exception as e:
raise ValueError(f'Invalid {name} value: "{text}"')
assert min(values) >= min_value, f"Minimal value for {name} is {min_value}"
return values

def parse_resolutions(width_text, height_text):
widths = parse_values(width_text, "width", TrainDatasetAddBase.MIN_WIDTH)
heights = parse_values(height_text, "height", TrainDatasetAddBase.MIN_HEIGHT)
assert len(widths) == len(heights), f"Unmatched: {len(widths)} widths vs {len(heights)} heights"
return zip(widths, heights)

resolutions = parse_resolutions(kwargs["width"], kwargs["height"])
signatures = self.previous_dataset_signatures.copy()
for width, height in resolutions:
kwargs["width"] = width
kwargs["height"] = height
result = self.add_dataset_to_config(signatures, **kwargs)
return result

class OptimizerConfig:
@classmethod
def INPUT_TYPES(s):
@@ -1560,6 +1609,7 @@ def extract(self, original_model, finetuned_model, output_path, dim, save_dtype,
"FluxTrainModelSelect": FluxTrainModelSelect,
"TrainDatasetGeneralConfig": TrainDatasetGeneralConfig,
"TrainDatasetAdd": TrainDatasetAdd,
"TrainDatasetAddMultiReso": TrainDatasetAddMultiReso,
"FluxTrainLoop": FluxTrainLoop,
"VisualizeLoss": VisualizeLoss,
"FluxTrainValidate": FluxTrainValidate,
@@ -1581,8 +1631,9 @@ def extract(self, original_model, finetuned_model, output_path, dim, save_dtype,
"InitFluxLoRATraining": "Init Flux LoRA Training",
"InitFluxTraining": "Init Flux Training",
"FluxTrainModelSelect": "FluxTrain ModelSelect",
"TrainDatasetGeneralConfig": "TrainDatasetGeneralConfig",
"TrainDatasetAdd": "TrainDatasetAdd",
"TrainDatasetGeneralConfig": "Train Dataset General Config",
"TrainDatasetAdd": "Train Dataset Add",
"TrainDatasetAddMultiReso": "Train Dataset Add Multi Resolutions",
"FluxTrainLoop": "Flux Train Loop",
"VisualizeLoss": "Visualize Loss",
"FluxTrainValidate": "Flux Train Validate",
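And a sketch of the single entry that each pair contributes to dataset_config["datasets"], based on add_dataset_to_config above; the input values below are hypothetical placeholders, not defaults taken from the node.

import json

# Hypothetical example inputs, for illustration only.
example = {
    "width": 768, "height": 768, "batch_size": 2,
    "enable_bucket": True, "bucket_no_upscale": False,
    "min_bucket_reso": 256, "max_bucket_reso": 1024,
    "dataset_path": "datasets/my_subject", "class_tokens": "mysubject", "num_repeats": 1,
}

# Shape of one entry appended to existing_datasets["datasets"]; TrainDatasetAddMultiReso
# appends one such entry per (width, height) pair, dropping entries whose signature was
# recorded on a previous run so re-executing the node does not duplicate datasets.
new_dataset = {
    "resolution": (example["width"], example["height"]),
    "batch_size": example["batch_size"],
    "enable_bucket": example["enable_bucket"],
    "bucket_no_upscale": example["bucket_no_upscale"],
    "min_bucket_reso": example["min_bucket_reso"],
    "max_bucket_reso": example["max_bucket_reso"],
    "subsets": [{
        "image_dir": example["dataset_path"],
        "class_tokens": example["class_tokens"],
        "num_repeats": example["num_repeats"],
    }],
}

print(json.dumps(new_dataset, indent=2))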