-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
## Summary Nodes to support SD3.5 txt2img generations * adds SD3.5 to starter models * adds default workflow for SD3.5 txt2img ## Related Issues / Discussions <!--WHEN APPLICABLE: List any related issues or discussions on github or discord. If this PR closes an issue, please use the "Closes #1234" format, so that the issue will be automatically closed when the PR merges.--> ## QA Instructions <!--WHEN APPLICABLE: Describe how you have tested the changes in this PR. Provide enough detail that a reviewer can reproduce your tests.--> ## Merge Plan <!--WHEN APPLICABLE: Large PRs, or PRs that touch sensitive things like DB schemas, may need some care when merging. For example, a careful rebase by the change author, timing to not interfere with a pending release, or a message to contributors on discord after merging.--> ## Checklist - [ ] _The PR has a short but descriptive title, suitable for a changelog_ - [ ] _Tests added / updated (if applicable)_ - [ ] _Documentation added / updated (if applicable)_
- Loading branch information
Showing
41 changed files
with
2,523 additions
and
171 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
from typing import Literal | ||
|
||
from invokeai.app.invocations.baseinvocation import ( | ||
BaseInvocation, | ||
BaseInvocationOutput, | ||
Classification, | ||
invocation, | ||
invocation_output, | ||
) | ||
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType | ||
from invokeai.app.invocations.model import CLIPField, ModelIdentifierField, T5EncoderField, TransformerField, VAEField | ||
from invokeai.app.services.shared.invocation_context import InvocationContext | ||
from invokeai.backend.flux.util import max_seq_lengths | ||
from invokeai.backend.model_manager.config import ( | ||
CheckpointConfigBase, | ||
SubModelType, | ||
) | ||
|
||
|
||
@invocation_output("flux_model_loader_output") | ||
class FluxModelLoaderOutput(BaseInvocationOutput): | ||
"""Flux base model loader output""" | ||
|
||
transformer: TransformerField = OutputField(description=FieldDescriptions.transformer, title="Transformer") | ||
clip: CLIPField = OutputField(description=FieldDescriptions.clip, title="CLIP") | ||
t5_encoder: T5EncoderField = OutputField(description=FieldDescriptions.t5_encoder, title="T5 Encoder") | ||
vae: VAEField = OutputField(description=FieldDescriptions.vae, title="VAE") | ||
max_seq_len: Literal[256, 512] = OutputField( | ||
description="The max sequence length to used for the T5 encoder. (256 for schnell transformer, 512 for dev transformer)", | ||
title="Max Seq Length", | ||
) | ||
|
||
|
||
@invocation( | ||
"flux_model_loader", | ||
title="Flux Main Model", | ||
tags=["model", "flux"], | ||
category="model", | ||
version="1.0.4", | ||
classification=Classification.Prototype, | ||
) | ||
class FluxModelLoaderInvocation(BaseInvocation): | ||
"""Loads a flux base model, outputting its submodels.""" | ||
|
||
model: ModelIdentifierField = InputField( | ||
description=FieldDescriptions.flux_model, | ||
ui_type=UIType.FluxMainModel, | ||
input=Input.Direct, | ||
) | ||
|
||
t5_encoder_model: ModelIdentifierField = InputField( | ||
description=FieldDescriptions.t5_encoder, ui_type=UIType.T5EncoderModel, input=Input.Direct, title="T5 Encoder" | ||
) | ||
|
||
clip_embed_model: ModelIdentifierField = InputField( | ||
description=FieldDescriptions.clip_embed_model, | ||
ui_type=UIType.CLIPEmbedModel, | ||
input=Input.Direct, | ||
title="CLIP Embed", | ||
) | ||
|
||
vae_model: ModelIdentifierField = InputField( | ||
description=FieldDescriptions.vae_model, ui_type=UIType.FluxVAEModel, title="VAE" | ||
) | ||
|
||
def invoke(self, context: InvocationContext) -> FluxModelLoaderOutput: | ||
for key in [self.model.key, self.t5_encoder_model.key, self.clip_embed_model.key, self.vae_model.key]: | ||
if not context.models.exists(key): | ||
raise ValueError(f"Unknown model: {key}") | ||
|
||
transformer = self.model.model_copy(update={"submodel_type": SubModelType.Transformer}) | ||
vae = self.vae_model.model_copy(update={"submodel_type": SubModelType.VAE}) | ||
|
||
tokenizer = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.Tokenizer}) | ||
clip_encoder = self.clip_embed_model.model_copy(update={"submodel_type": SubModelType.TextEncoder}) | ||
|
||
tokenizer2 = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.Tokenizer2}) | ||
t5_encoder = self.t5_encoder_model.model_copy(update={"submodel_type": SubModelType.TextEncoder2}) | ||
|
||
transformer_config = context.models.get_config(transformer) | ||
assert isinstance(transformer_config, CheckpointConfigBase) | ||
|
||
return FluxModelLoaderOutput( | ||
transformer=TransformerField(transformer=transformer, loras=[]), | ||
clip=CLIPField(tokenizer=tokenizer, text_encoder=clip_encoder, loras=[], skipped_layers=0), | ||
t5_encoder=T5EncoderField(tokenizer=tokenizer2, text_encoder=t5_encoder), | ||
vae=VAEField(vae=vae), | ||
max_seq_len=max_seq_lengths[transformer_config.config_path], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.