From 34c1ff13d30483e2a2b1e75879557024a8524a89 Mon Sep 17 00:00:00 2001
From: wanderor
Date: Tue, 17 Sep 2024 16:59:25 +0800
Subject: [PATCH] Adds TrainDatasetAddMultiReso

Adds the TrainDatasetAddMultiReso node, which is similar to the
TrainDatasetAdd node but supports specifying multiple resolutions.
---
 nodes.py | 117 +++++++++++++++++++++++++++++++++++++++----------------
 1 file changed, 84 insertions(+), 33 deletions(-)

diff --git a/nodes.py b/nodes.py
index d9d745c..2bb1e94 100644
--- a/nodes.py
+++ b/nodes.py
@@ -1,3 +1,4 @@
+import copy
 import os
 import torch
 from torchvision import transforms
@@ -144,17 +145,22 @@ def create_config(self, dataset_path, class_tokens, num_repeats):
             }
 
         return reg_subset,
-    
-class TrainDatasetAdd:
+
+class TrainDatasetAddBase:
+    MIN_WIDTH = 64
+    MIN_HEIGHT = 64
+    DEF_WIDTH = 1024
+    DEF_HEIGHT = 1024
+
     def __init__(self):
-        self.previous_dataset_signature = None
-    
+        self.previous_dataset_signatures = set()
+
     @classmethod
-    def INPUT_TYPES(s):
+    def base_input_types(s):
         return {"required": {
             "dataset_config": ("JSON",),
-            "width": ("INT",{"min": 64, "default": 1024, "tooltip": "base resolution width"}),
-            "height": ("INT",{"min": 64, "default": 1024, "tooltip": "base resolution height"}),
+            "width": None,  # placeholder to preserve order
+            "height": None, # for backward compatibility
             "batch_size": ("INT",{"min": 1, "default": 2, "tooltip": "Higher batch size uses more memory and generalizes the training more"}),
             "dataset_path": ("STRING",{"multiline": True, "default": "", "tooltip": "path to dataset, root is the 'ComfyUI' folder, with windows portable 'ComfyUI_windows_portable'"}),
             "class_tokens": ("STRING",{"multiline": True, "default": "", "tooltip": "aka trigger word, if specified, will be added to the start of each caption, if no captions exist, will be used on it's own"}),
@@ -174,56 +180,99 @@ def INPUT_TYPES(s):
     FUNCTION = "create_config"
     CATEGORY = "FluxTrainer"
 
-    def create_config(self, dataset_config, dataset_path, class_tokens, width, height, batch_size, num_repeats, enable_bucket,
-                      bucket_no_upscale, min_bucket_reso, max_bucket_reso, regularization=None):
-        
+    def add_dataset_to_config(self, dataset_signatures_to_drop, **kwargs):
         new_dataset = {
-            "resolution": (width, height),
-            "batch_size": batch_size,
-            "enable_bucket": enable_bucket,
-            "bucket_no_upscale": bucket_no_upscale,
-            "min_bucket_reso": min_bucket_reso,
-            "max_bucket_reso": max_bucket_reso,
-            "subsets": [
-                {
-                    "image_dir": dataset_path,
-                    "class_tokens": class_tokens,
-                    "num_repeats": num_repeats
-                }
-            ]
+            "resolution": (kwargs["width"], kwargs["height"]),
+            "batch_size": kwargs["batch_size"],
+            "enable_bucket": kwargs["enable_bucket"],
+            "bucket_no_upscale": kwargs["bucket_no_upscale"],
+            "min_bucket_reso": kwargs["min_bucket_reso"],
+            "max_bucket_reso": kwargs["max_bucket_reso"],
+            "subsets": [{
+                "image_dir": kwargs["dataset_path"],
+                "class_tokens": kwargs["class_tokens"],
+                "num_repeats": kwargs["num_repeats"],
+            }],
         }
-        if regularization is not None:
-            new_dataset["subsets"].append(regularization)
+        if kwargs.get("regularization") is not None:
+            new_dataset["subsets"].append(kwargs["regularization"])
 
         # Generate a signature for the new dataset
         new_dataset_signature = self.generate_signature(new_dataset)
 
         # Load the existing datasets
-        existing_datasets = json.loads(dataset_config["datasets"])
+        existing_datasets = json.loads(kwargs["dataset_config"]["datasets"])
 
         # Remove the previously added dataset if it exists
-        if self.previous_dataset_signature:
+        if len(dataset_signatures_to_drop) > 0:
             existing_datasets["datasets"] = [
                 ds for ds in existing_datasets["datasets"]
-                if self.generate_signature(ds) != self.previous_dataset_signature
+                if self.generate_signature(ds) not in dataset_signatures_to_drop
             ]
 
         # Add the new dataset
         existing_datasets["datasets"].append(new_dataset)
 
         # Store the new dataset signature for future runs
-        self.previous_dataset_signature = new_dataset_signature
+        self.previous_dataset_signatures.add(new_dataset_signature)
 
         # Convert back to JSON and update dataset_config
         updated_dataset_json = json.dumps(existing_datasets, indent=2)
-        dataset_config["datasets"] = updated_dataset_json
+        kwargs["dataset_config"]["datasets"] = updated_dataset_json
 
-        return dataset_config,
+        return kwargs["dataset_config"],
 
     def generate_signature(self, dataset):
         # Create a unique signature for the dataset based on its attributes
         return json.dumps(dataset, sort_keys=True)
 
+class TrainDatasetAdd(TrainDatasetAddBase):
+    @classmethod
+    def INPUT_TYPES(s):
+        d = copy.deepcopy(s.base_input_types())
+        d["required"].update({
+            "width": ("INT",{"min": s.MIN_WIDTH, "default": s.DEF_WIDTH, "tooltip": "base resolution width"}),
+            "height": ("INT",{"min": s.MIN_HEIGHT, "default": s.DEF_HEIGHT, "tooltip": "base resolution height"}),
+        })
+        return d
+
+    def create_config(self, **kwargs):
+        return self.add_dataset_to_config(self.previous_dataset_signatures, **kwargs)
+
+class TrainDatasetAddMultiReso(TrainDatasetAddBase):
+    @classmethod
+    def INPUT_TYPES(s):
+        d = copy.deepcopy(s.base_input_types())
+        d["required"].update({
+            "width": ("STRING",{"multiline": False, "default": str(s.DEF_WIDTH), "tooltip": "comma-separated list of base resolution widths"}),
+            "height": ("STRING",{"multiline": False, "default": str(s.DEF_HEIGHT), "tooltip": "comma-separated list of base resolution heights"}),
+        })
+        return d
+
+    def create_config(self, **kwargs):
+        def parse_values(text, name, min_value):
+            try:
+                values = [int(v) for v in text.split(",")]
+                assert len(values) > 0
+            except Exception:
+                raise ValueError(f'Invalid {name} value: "{text}"')
+            assert min(values) >= min_value, f"Minimal value for {name} is {min_value}"
+            return values
+
+        def parse_resolutions(width_text, height_text):
+            widths = parse_values(width_text, "width", TrainDatasetAddBase.MIN_WIDTH)
+            heights = parse_values(height_text, "height", TrainDatasetAddBase.MIN_HEIGHT)
+            assert len(widths) == len(heights), f"Unmatched: {len(widths)} widths vs {len(heights)} heights"
+            return zip(widths, heights)
+
+        resolutions = parse_resolutions(kwargs["width"], kwargs["height"])
+        signatures = self.previous_dataset_signatures.copy()
+        for width, height in resolutions:
+            kwargs["width"] = width
+            kwargs["height"] = height
+            result = self.add_dataset_to_config(signatures, **kwargs)
+        return result
+
 class OptimizerConfig:
     @classmethod
     def INPUT_TYPES(s):
@@ -1560,6 +1609,7 @@ def extract(self, original_model, finetuned_model, output_path, dim, save_dtype,
     "FluxTrainModelSelect": FluxTrainModelSelect,
     "TrainDatasetGeneralConfig": TrainDatasetGeneralConfig,
     "TrainDatasetAdd": TrainDatasetAdd,
+    "TrainDatasetAddMultiReso": TrainDatasetAddMultiReso,
     "FluxTrainLoop": FluxTrainLoop,
     "VisualizeLoss": VisualizeLoss,
     "FluxTrainValidate": FluxTrainValidate,
@@ -1581,8 +1631,9 @@ def extract(self, original_model, finetuned_model, output_path, dim, save_dtype,
     "InitFluxLoRATraining": "Init Flux LoRA Training",
     "InitFluxTraining": "Init Flux Training",
     "FluxTrainModelSelect": "FluxTrain ModelSelect",
-    "TrainDatasetGeneralConfig": "TrainDatasetGeneralConfig",
"TrainDatasetGeneralConfig", - "TrainDatasetAdd": "TrainDatasetAdd", + "TrainDatasetGeneralConfig": "Train Dataset General Config", + "TrainDatasetAdd": "Train Dataset Add", + "TrainDatasetAddMultiReso": "Train Dataset Add Multi Resolutions", "FluxTrainLoop": "Flux Train Loop", "VisualizeLoss": "Visualize Loss", "FluxTrainValidate": "Flux Train Validate",