diff --git a/PixArt/loader.py b/PixArt/loader.py
index e84fe14..7f3df5c 100644
--- a/PixArt/loader.py
+++ b/PixArt/loader.py
@@ -23,7 +23,7 @@ def __init__(self, model_conf):
     def model_type(self, state_dict, prefix=""):
         return comfy.model_base.ModelType.EPS
 
-def load_pixart(model_path, model_conf):
+def load_pixart(model_path, model_conf, target_dtype):
     state_dict = comfy.utils.load_torch_file(model_path)
     state_dict = state_dict.get("model", state_dict)
 
@@ -36,8 +36,11 @@ def load_pixart(model_path, model_conf):
     if "adaln_single.linear.weight" in state_dict:
         state_dict = convert_state_dict(state_dict) # Diffusers
 
-    parameters = comfy.utils.calculate_parameters(state_dict)
-    unet_dtype = model_management.unet_dtype(model_params=parameters)
+    if target_dtype is None:
+        parameters = comfy.utils.calculate_parameters(state_dict)
+        unet_dtype = model_management.unet_dtype(model_params=parameters)
+    else:
+        unet_dtype = target_dtype
 
     model_conf = EXM_PixArt(model_conf) # convert to object
     model = comfy.model_base.BaseModel(
diff --git a/PixArt/nodes.py b/PixArt/nodes.py
index 184e6c3..2a3e72c 100644
--- a/PixArt/nodes.py
+++ b/PixArt/nodes.py
@@ -9,6 +9,14 @@
 from .loader import load_pixart
 from .sampler import sample_pixart
 
+dtypes = [
+    "default",
+    "auto (comfy)",
+    "float32",
+    "float16",
+    "bfloat16",
+]
+
 class PixArtCheckpointLoader:
     @classmethod
     def INPUT_TYPES(s):
@@ -16,6 +24,7 @@ def INPUT_TYPES(s):
             "required": {
                 "ckpt_name": (folder_paths.get_filename_list("checkpoints"),),
                 "model": (list(pixart_conf.keys()),),
+                "dtype": (dtypes,),
             }
         }
     RETURN_TYPES = ("MODEL",)
@@ -24,12 +33,28 @@
     CATEGORY = "ExtraModels/PixArt"
     TITLE = "PixArt Checkpoint Loader"
 
-    def load_checkpoint(self, ckpt_name, model):
+    def load_checkpoint(self, ckpt_name, model, dtype):
         ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
         model_conf = pixart_conf[model]
+
+        target_dtype = None
+        if dtype == "default":
+            target_dtype = torch.float16
+        elif dtype == 'auto (comfy)':
+            target_dtype = None
+        elif dtype == 'float32':
+            target_dtype = torch.float32
+        elif dtype == 'float16':
+            target_dtype = torch.float16
+        elif dtype == 'bfloat16':
+            target_dtype = torch.bfloat16
+        else:
+            raise ValueError(f"Invalid dtype: {dtype}")
+
         model = load_pixart(
             model_path = ckpt_path,
             model_conf = model_conf,
+            target_dtype = target_dtype,
         )
         return (model,)
 