Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: add dtype selection to PixArt #12

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions PixArt/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __init__(self, model_conf):
def model_type(self, state_dict, prefix=""):
return comfy.model_base.ModelType.EPS

def load_pixart(model_path, model_conf):
def load_pixart(model_path, model_conf, target_dtype):
state_dict = comfy.utils.load_torch_file(model_path)
state_dict = state_dict.get("model", state_dict)

Expand All @@ -36,8 +36,11 @@ def load_pixart(model_path, model_conf):
if "adaln_single.linear.weight" in state_dict:
state_dict = convert_state_dict(state_dict) # Diffusers

parameters = comfy.utils.calculate_parameters(state_dict)
unet_dtype = model_management.unet_dtype(model_params=parameters)
if target_dtype is None:
parameters = comfy.utils.calculate_parameters(state_dict)
unet_dtype = model_management.unet_dtype(model_params=parameters)
else:
unet_dtype = target_dtype

model_conf = EXM_PixArt(model_conf) # convert to object
model = comfy.model_base.BaseModel(
Expand Down
27 changes: 26 additions & 1 deletion PixArt/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,22 @@
from .loader import load_pixart
from .sampler import sample_pixart

dtypes = [
"default",
"auto (comfy)",
"float32",
"float16",
"bfloat16",
]

class PixArtCheckpointLoader:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"ckpt_name": (folder_paths.get_filename_list("checkpoints"),),
"model": (list(pixart_conf.keys()),),
"dtype": (dtypes,),
}
}
RETURN_TYPES = ("MODEL",)
Expand All @@ -24,12 +33,28 @@ def INPUT_TYPES(s):
CATEGORY = "ExtraModels/PixArt"
TITLE = "PixArt Checkpoint Loader"

def load_checkpoint(self, ckpt_name, model):
def load_checkpoint(self, ckpt_name, model, dtype):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
model_conf = pixart_conf[model]

target_dtype = None
if dtype == "default":
target_dtype = torch.float16
elif dtype == 'auto (comfy)':
target_dtype = None
elif dtype == 'float32':
target_dtype = torch.float32
elif dtype == 'float16':
target_dtype = torch.float16
elif dtype == 'bfloat16':
target_dtype = torch.bfloat16
else:
raise ValueError(f"Invalid dtype: {dtype}")

model = load_pixart(
model_path = ckpt_path,
model_conf = model_conf,
target_dtype = target_dtype,
)
return (model,)

Expand Down