diff --git a/HunYuanDiT/LICENSE-HYDiT b/HunYuanDiT/LICENSE-HYDiT new file mode 100644 index 0000000..61ea65d --- /dev/null +++ b/HunYuanDiT/LICENSE-HYDiT @@ -0,0 +1,74 @@ +TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT +Tencent Hunyuan Release Date: 2024/5/14 +By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent Hunyuan Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately. +1. DEFINITIONS. +a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A. +b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of the Hunyuan Works or any portion or element thereof set forth herein. +c. “Documentation” shall mean the specifications, manuals and documentation for Tencent Hunyuan made publicly available by Tencent. +d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means. +e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent Hunyuan Works for any purpose and in any field of use. +f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent Hunyuan and Documentation (and any portion thereof) as made available by Tencent under this Agreement. +g. “Model Derivatives” shall mean all: (i) modifications to Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; (ii) works based on Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent Hunyuan or any Model Derivative of Tencent Hunyuan, to that model in order to cause that model to perform similarly to Tencent Hunyuan or a Model Derivative of Tencent Hunyuan, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent Hunyuan or a Model Derivative of Tencent Hunyuan for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives. +h. “Output” shall mean the information and/or content output of Tencent Hunyuan or a Model Derivative that results from operating or otherwise using Tencent Hunyuan or a Model Derivative, including via a Hosted Service. +i. “Tencent,” “We” or “Us” shall mean THL A29 Limited. +j. “Tencent Hunyuan” shall mean the large language models, image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us at https://huggingface.co/Tencent-Hunyuan/HunyuanDiT and https://github.com/Tencent/HunyuanDiT . +k. “Tencent Hunyuan Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof. +l. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You. +m. “including” shall mean including but not limited to. +2. GRANT OF RIGHTS. 
+We grant You a non-exclusive, worldwide, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy. +3. DISTRIBUTION. +You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent Hunyuan Works, provided that You meet all of the following conditions: +a. You must provide all such Third Party recipients of the Tencent Hunyuan Works or products or services using them a copy of this Agreement; +b. You must cause any modified files to carry prominent notices stating that You changed the files; +c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent Hunyuan Works; and (ii) mark the products or services developed by using the Tencent Hunyuan Works to indicate that the product/service is “Powered by Tencent Hunyuan”; and +d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent Hunyuan is licensed under the Tencent Hunyuan Community License Agreement, Copyright © 2024 Tencent. All Rights Reserved. The trademark rights of “Tencent Hunyuan” are owned by Tencent or its affiliate.” +You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement. If You receive Tencent Hunyuan Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You. +4. ADDITIONAL COMMERCIAL TERMS. +If, on the Tencent Hunyuan version release date, the monthly active users of all products or services made available by or for Licensee is greater than 100 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights. +5. RULES OF USE. +a. Your use of the Tencent Hunyuan Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent Hunyuan Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent Hunyuan Works and You must provide notice to subsequent users to whom You distribute that Tencent Hunyuan Works are subject to the use restrictions in these Sections 5(a) and 5(b). +b. 
You must not use the Tencent Hunyuan Works or any Output or results of the Tencent Hunyuan Works to improve any other large language model (other than Tencent Hunyuan or Model Derivatives thereof). +6. INTELLECTUAL PROPERTY. +a. Subject to Tencent’s ownership of Tencent Hunyuan Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You. +b. No trademark licenses are granted under this Agreement, and in connection with the Tencent Hunyuan Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent Hunyuan Works. Tencent hereby grants You a license to use “Tencent Hunyuan” (the “Mark”) solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent. +c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent Hunyuan Works. +d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses. +7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY. +a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent Hunyuan Works or to grant any license thereto. +b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HUNYUAN WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT. +c. TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING. +8. SURVIVAL AND TERMINATION. +a. 
The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. +b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent Hunyuan Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement. +9. GOVERNING LAW AND JURISDICTION. +a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. +b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute. +  + +EXHIBIT A +ACCEPTABLE USE POLICY + +Tencent reserves the right to update this Acceptable Use Policy from time to time. +Last modified: 2024/5/14 + +Tencent endeavors to promote safe and fair use of its tools and features, including Tencent Hunyuan. You agree not to use Tencent Hunyuan or Model Derivatives: +1. In any way that violates any applicable national, federal, state, local, international or any other law or regulation; +2. To harm Yourself or others; +3. To repurpose or distribute output from Tencent Hunyuan or any Model Derivatives to harm Yourself or others; +4. To override or circumvent the safety guardrails and safeguards We have put in place; +5. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way; +6. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections; +7. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement; +8. To intentionally defame, disparage or otherwise harass others; +9. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems; +10. To generate or disseminate personal identifiable information with the purpose of harming others; +11. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including –through the use of bot generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated; +12. To impersonate another individual without consent, authorization, or legal right; +13. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance); +14. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions; +15. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism; +16. 
For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics; +17. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm; +18. For military purposes; +19. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices. diff --git a/HunYuanDiT/conf.py b/HunYuanDiT/conf.py new file mode 100644 index 0000000..d057a39 --- /dev/null +++ b/HunYuanDiT/conf.py @@ -0,0 +1,44 @@ +""" +List of all HYDiT model types / settings +""" +sampling_settings = { + "beta_schedule" : "linear", + "linear_start" : 0.00085, + "linear_end" : 0.03, + "timesteps" : 1000, +} + +from argparse import Namespace +hydit_args = Namespace(**{ # normally from argparse + "infer_mode": "torch", + "norm": "layer", + "learn_sigma": True, + "text_states_dim": 1024, + "text_states_dim_t5": 2048, + "text_len": 77, + "text_len_t5": 256, +}) + +hydit_conf = { + "G/2": { # Seems to be the main one + "unet_config": { + "depth" : 40, + "num_heads" : 16, + "patch_size" : 2, + "hidden_size" : 1408, + "mlp_ratio" : 4.3637, + "input_size": (1024//8, 1024//8), + "args": hydit_args, + }, + "sampling_settings" : sampling_settings, + }, +} + +# these are the same as regular DiT, I think +from ..DiT.conf import dit_conf +for name in ["XL/2", "L/2", "B/2"]: + hydit_conf[name] = { + "unet_config": dit_conf[name]["unet_config"].copy(), + "sampling_settings": sampling_settings, + } + hydit_conf[name]["unet_config"]["args"] = hydit_args diff --git a/HunYuanDiT/config_clip.json b/HunYuanDiT/config_clip.json new file mode 100644 index 0000000..f629874 --- /dev/null +++ b/HunYuanDiT/config_clip.json @@ -0,0 +1,34 @@ +{ + "_name_or_path": "hfl/chinese-roberta-wwm-ext-large", + "architectures": [ + "BertModel" + ], + "attention_probs_dropout_prob": 0.1, + "bos_token_id": 0, + "classifier_dropout": null, + "directionality": "bidi", + "eos_token_id": 2, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 1024, + "initializer_range": 0.02, + "intermediate_size": 4096, + "layer_norm_eps": 1e-12, + "max_position_embeddings": 512, + "model_type": "bert", + "num_attention_heads": 16, + "num_hidden_layers": 24, + "output_past": true, + "pad_token_id": 0, + "pooler_fc_size": 768, + "pooler_num_attention_heads": 12, + "pooler_num_fc_layers": 3, + "pooler_size_per_head": 128, + "pooler_type": "first_token_transform", + "position_embedding_type": "absolute", + "torch_dtype": "float32", + "transformers_version": "4.22.1", + "type_vocab_size": 2, + "use_cache": true, + "vocab_size": 47020 +} diff --git a/HunYuanDiT/config_mt5.json b/HunYuanDiT/config_mt5.json new file mode 100644 index 0000000..d5c7028 --- /dev/null +++ b/HunYuanDiT/config_mt5.json @@ -0,0 +1,33 @@ +{ + "_name_or_path": "mt5", + "architectures": [ + "MT5ForConditionalGeneration" + ], + "classifier_dropout": 0.0, + "d_ff": 5120, + "d_kv": 64, + "d_model": 2048, + "decoder_start_token_id": 0, + "dense_act_fn": "gelu_new", + "dropout_rate": 0.1, + "eos_token_id": 1, + "feed_forward_proj": "gated-gelu", + "initializer_factor": 1.0, + 
"is_encoder_decoder": true, + "is_gated_act": true, + "layer_norm_epsilon": 1e-06, + "model_type": "mt5", + "num_decoder_layers": 24, + "num_heads": 32, + "num_layers": 24, + "output_past": true, + "pad_token_id": 0, + "relative_attention_max_distance": 128, + "relative_attention_num_buckets": 32, + "tie_word_embeddings": false, + "tokenizer_class": "T5Tokenizer", + "torch_dtype": "float16", + "transformers_version": "4.40.2", + "use_cache": true, + "vocab_size": 250112 +} diff --git a/HunYuanDiT/loader.py b/HunYuanDiT/loader.py new file mode 100644 index 0000000..35940de --- /dev/null +++ b/HunYuanDiT/loader.py @@ -0,0 +1,77 @@ +import comfy.supported_models_base +import comfy.latent_formats +import comfy.model_patcher +import comfy.model_base +import comfy.utils +import comfy.conds +import torch +from comfy import model_management +from tqdm import tqdm + +class EXM_HYDiT(comfy.supported_models_base.BASE): + unet_config = {} + unet_extra_config = {} + latent_format = comfy.latent_formats.SDXL + + def __init__(self, model_conf): + self.unet_config = model_conf.get("unet_config", {}) + self.sampling_settings = model_conf.get("sampling_settings", {}) + self.latent_format = self.latent_format() + # UNET is handled by extension + self.unet_config["disable_unet_model_creation"] = True + + def model_type(self, state_dict, prefix=""): + return comfy.model_base.ModelType.V_PREDICTION + +class EXM_HYDiT_Model(comfy.model_base.BaseModel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + + for name in ["context_t5", "context_mask", "context_t5_mask"]: + out[name] = comfy.conds.CONDRegular(kwargs[name]) + + return out + +def load_hydit(model_path, model_conf): + state_dict = comfy.utils.load_torch_file(model_path) + state_dict = state_dict.get("model", state_dict) + + parameters = comfy.utils.calculate_parameters(state_dict) + unet_dtype = model_management.unet_dtype(model_params=parameters) + load_device = comfy.model_management.get_torch_device() + offload_device = comfy.model_management.unet_offload_device() + + # ignore fp8/etc and use directly for now + manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device) + if manual_cast_dtype: + print(f"HunYuanDiT: falling back to {manual_cast_dtype}") + unet_dtype = manual_cast_dtype + + model_conf = EXM_HYDiT(model_conf) + model = EXM_HYDiT_Model( + model_conf, + model_type=comfy.model_base.ModelType.V_PREDICTION, + device=model_management.get_torch_device() + ) + + from .models.models import HunYuanDiT + model.diffusion_model = HunYuanDiT( + **model_conf.unet_config, + log_fn=tqdm.write, + ) + + model.diffusion_model.load_state_dict(state_dict) + model.diffusion_model.dtype = unet_dtype + model.diffusion_model.eval() + model.diffusion_model.to(unet_dtype) + + model_patcher = comfy.model_patcher.ModelPatcher( + model, + load_device = load_device, + offload_device = offload_device, + current_device = "cpu", + ) + return model_patcher diff --git a/HunYuanDiT/models/attn_layers.py b/HunYuanDiT/models/attn_layers.py new file mode 100644 index 0000000..4308af9 --- /dev/null +++ b/HunYuanDiT/models/attn_layers.py @@ -0,0 +1,377 @@ +import torch +import torch.nn as nn +from typing import Tuple, Union, Optional + +try: + import flash_attn + if hasattr(flash_attn, '__version__') and int(flash_attn.__version__[0]) == 2: + from flash_attn.flash_attn_interface import flash_attn_kvpacked_func + from flash_attn.modules.mha import 
FlashSelfAttention, FlashCrossAttention + else: + from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func + from flash_attn.modules.mha import FlashSelfAttention, FlashCrossAttention +except Exception as e: + print(f'flash_attn import failed: {e}') + + +def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x: torch.Tensor, head_first=False): + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + Args: + freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + head_first (bool): head dimension first (except batch dim) or not. + + Returns: + torch.Tensor: Reshaped frequency tensor. + + Raises: + AssertionError: If the frequency tensor doesn't match the expected shape. + AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions. + """ + ndim = x.ndim + assert 0 <= 1 < ndim + + if isinstance(freqs_cis, tuple): + # freqs_cis: (cos, sin) in real space + if head_first: + assert freqs_cis[0].shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}' + shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + else: + assert freqs_cis[0].shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}' + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape) + else: + # freqs_cis: values in complex space + if head_first: + assert freqs_cis.shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}' + shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + else: + assert freqs_cis.shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}' + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def rotate_half(x): + x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + return torch.stack([-x_imag, x_real], dim=-1).flatten(3) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: Optional[torch.Tensor], + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], + head_first: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. + + This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided + frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are + returned as real tensors. + + Args: + xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D] + xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D] + freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Precomputed frequency tensor for complex exponentials. + head_first (bool): head dimension first (except batch dim) or not. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. 
+ + """ + xk_out = None + if isinstance(freqs_cis, tuple): + cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D] + cos, sin = cos.to(xq.device), sin.to(xq.device) + xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq) + if xk is not None: + xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk) + else: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2] + freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device) # [S, D//2] --> [1, S, 1, D//2] + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq) + if xk is not None: + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) # [B, S, H, D//2] + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk) + + return xq_out, xk_out + + +class FlashSelfMHAModified(nn.Module): + """ + Use QK Normalization. + """ + def __init__(self, + dim, + num_heads, + qkv_bias=True, + qk_norm=False, + attn_drop=0.0, + proj_drop=0.0, + device=None, + dtype=None, + norm_layer=nn.LayerNorm, + ): + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.dim = dim + self.num_heads = num_heads + assert self.dim % num_heads == 0, "self.kdim must be divisible by num_heads" + self.head_dim = self.dim // num_heads + assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8" + + self.Wqkv = nn.Linear(dim, 3 * dim, bias=qkv_bias, **factory_kwargs) + # TODO: eps should be 1 / 65530 if using fp16 + self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + self.inner_attn = FlashSelfAttention(attention_dropout=attn_drop) + self.out_proj = nn.Linear(dim, dim, bias=qkv_bias, **factory_kwargs) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, freqs_cis_img=None): + """ + Parameters + ---------- + x: torch.Tensor + (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) + freqs_cis_img: torch.Tensor + (batch, hidden_dim // 2), RoPE for image + """ + b, s, d = x.shape + + qkv = self.Wqkv(x) + qkv = qkv.view(b, s, 3, self.num_heads, self.head_dim) # [b, s, 3, h, d] + q, k, v = qkv.unbind(dim=2) # [b, s, h, d] + q = self.q_norm(q).half() # [b, s, h, d] + k = self.k_norm(k).half() + + # Apply RoPE if needed + if freqs_cis_img is not None: + qq, kk = apply_rotary_emb(q, k, freqs_cis_img) + assert qq.shape == q.shape and kk.shape == k.shape, f'qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}' + q, k = qq, kk + + qkv = torch.stack([q, k, v], dim=2) # [b, s, 3, h, d] + context = self.inner_attn(qkv) + out = self.out_proj(context.view(b, s, d)) + out = self.proj_drop(out) + + out_tuple = (out,) + + return out_tuple + + +class FlashCrossMHAModified(nn.Module): + """ + Use QK Normalization. 
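+    Cross-attention using the flash-attention kernels, with optional QK LayerNorm and RoPE applied to
+    the queries; Q/K are cast to half precision before the flash kernel is called.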
+ """ + def __init__(self, + qdim, + kdim, + num_heads, + qkv_bias=True, + qk_norm=False, + attn_drop=0.0, + proj_drop=0.0, + device=None, + dtype=None, + norm_layer=nn.LayerNorm, + ): + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.qdim = qdim + self.kdim = kdim + self.num_heads = num_heads + assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads" + self.head_dim = self.qdim // num_heads + assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8" + + self.scale = self.head_dim ** -0.5 + + self.q_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs) + self.kv_proj = nn.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs) + + # TODO: eps should be 1 / 65530 if using fp16 + self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + + self.inner_attn = FlashCrossAttention(attention_dropout=attn_drop) + self.out_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, y, freqs_cis_img=None): + """ + Parameters + ---------- + x: torch.Tensor + (batch, seqlen1, hidden_dim) (where hidden_dim = num_heads * head_dim) + y: torch.Tensor + (batch, seqlen2, hidden_dim2) + freqs_cis_img: torch.Tensor + (batch, hidden_dim // num_heads), RoPE for image + """ + b, s1, _ = x.shape # [b, s1, D] + _, s2, _ = y.shape # [b, s2, 1024] + + q = self.q_proj(x).view(b, s1, self.num_heads, self.head_dim) # [b, s1, h, d] + kv = self.kv_proj(y).view(b, s2, 2, self.num_heads, self.head_dim) # [b, s2, 2, h, d] + k, v = kv.unbind(dim=2) # [b, s2, h, d] + q = self.q_norm(q).half() # [b, s1, h, d] + k = self.k_norm(k).half() # [b, s2, h, d] + + # Apply RoPE if needed + if freqs_cis_img is not None: + qq, _ = apply_rotary_emb(q, None, freqs_cis_img) + assert qq.shape == q.shape, f'qq: {qq.shape}, q: {q.shape}' + q = qq # [b, s1, h, d] + kv = torch.stack([k, v], dim=2) # [b, s1, 2, h, d] + context = self.inner_attn(q, kv) # [b, s1, h, d] + context = context.view(b, s1, -1) # [b, s1, D] + + out = self.out_proj(context) + out = self.proj_drop(out) + + out_tuple = (out,) + + return out_tuple + + +class CrossAttention(nn.Module): + """ + Use QK Normalization. 
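+    Plain PyTorch fallback for cross-attention (no flash-attention), with optional QK LayerNorm and
+    RoPE applied to the queries.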
+ """ + def __init__(self, + qdim, + kdim, + num_heads, + qkv_bias=True, + qk_norm=False, + attn_drop=0.0, + proj_drop=0.0, + device=None, + dtype=None, + norm_layer=nn.LayerNorm, + ): + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.qdim = qdim + self.kdim = kdim + self.num_heads = num_heads + assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads" + self.head_dim = self.qdim // num_heads + assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8" + self.scale = self.head_dim ** -0.5 + + self.q_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs) + self.kv_proj = nn.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs) + + # TODO: eps should be 1 / 65530 if using fp16 + self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.out_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, y, freqs_cis_img=None): + """ + Parameters + ---------- + x: torch.Tensor + (batch, seqlen1, hidden_dim) (where hidden_dim = num heads * head dim) + y: torch.Tensor + (batch, seqlen2, hidden_dim2) + freqs_cis_img: torch.Tensor + (batch, hidden_dim // 2), RoPE for image + """ + b, s1, c = x.shape # [b, s1, D] + _, s2, c = y.shape # [b, s2, 1024] + + q = self.q_proj(x).view(b, s1, self.num_heads, self.head_dim) # [b, s1, h, d] + kv = self.kv_proj(y).view(b, s2, 2, self.num_heads, self.head_dim) # [b, s2, 2, h, d] + k, v = kv.unbind(dim=2) # [b, s, h, d] + q = self.q_norm(q) + k = self.k_norm(k) + + # Apply RoPE if needed + if freqs_cis_img is not None: + qq, _ = apply_rotary_emb(q, None, freqs_cis_img) + assert qq.shape == q.shape, f'qq: {qq.shape}, q: {q.shape}' + q = qq + + q = q * self.scale + q = q.transpose(-2, -3).contiguous() # q -> B, L1, H, C - B, H, L1, C + k = k.permute(0, 2, 3, 1).contiguous() # k -> B, L2, H, C - B, H, C, L2 + attn = q @ k # attn -> B, H, L1, L2 + attn = attn.softmax(dim=-1) # attn -> B, H, L1, L2 + attn = self.attn_drop(attn) + x = attn @ v.transpose(-2, -3) # v -> B, L2, H, C - B, H, L2, C x-> B, H, L1, C + context = x.transpose(1, 2) # context -> B, H, L1, C - B, L1, H, C + + context = context.contiguous().view(b, s1, -1) + + out = self.out_proj(context) # context.reshape - B, L1, -1 + out = self.proj_drop(out) + + out_tuple = (out,) + + return out_tuple + + +class Attention(nn.Module): + """ + We rename some layer names to align with flash attention + """ + def __init__(self, dim, num_heads, qkv_bias=True, qk_norm=False, attn_drop=0., proj_drop=0., + norm_layer=nn.LayerNorm, + ): + super().__init__() + self.dim = dim + self.num_heads = num_heads + assert self.dim % num_heads == 0, 'dim should be divisible by num_heads' + self.head_dim = self.dim // num_heads + # This assertion is aligned with flash attention + assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8" + self.scale = self.head_dim ** -0.5 + + # qkv --> Wqkv + self.Wqkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + # TODO: eps should be 1 / 65530 if using fp16 + self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + self.attn_drop 
= nn.Dropout(attn_drop) + self.out_proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, freqs_cis_img=None): + B, N, C = x.shape + qkv = self.Wqkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) # [3, b, h, s, d] + q, k, v = qkv.unbind(0) # [b, h, s, d] + q = self.q_norm(q) # [b, h, s, d] + k = self.k_norm(k) # [b, h, s, d] + + # Apply RoPE if needed + if freqs_cis_img is not None: + qq, kk = apply_rotary_emb(q, k, freqs_cis_img, head_first=True) + assert qq.shape == q.shape and kk.shape == k.shape, \ + f'qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}' + q, k = qq, kk + + q = q * self.scale + attn = q @ k.transpose(-2, -1) # [b, h, s, d] @ [b, h, d, s] + attn = attn.softmax(dim=-1) # [b, h, s, s] + attn = self.attn_drop(attn) + x = attn @ v # [b, h, s, d] + + x = x.transpose(1, 2).reshape(B, N, C) # [b, s, h, d] + x = self.out_proj(x) + x = self.proj_drop(x) + + out_tuple = (x,) + + return out_tuple diff --git a/HunYuanDiT/models/embedders.py b/HunYuanDiT/models/embedders.py new file mode 100644 index 0000000..9fe08cb --- /dev/null +++ b/HunYuanDiT/models/embedders.py @@ -0,0 +1,111 @@ +import math +import torch +import torch.nn as nn +from einops import repeat + +from timm.models.layers import to_2tuple + + +class PatchEmbed(nn.Module): + """ 2D Image to Patch Embedding + + Image to Patch Embedding using Conv2d + + A convolution based approach to patchifying a 2D image w/ embedding projection. + + Based on the impl in https://github.com/google-research/vision_transformer + + Hacked together by / Copyright 2020 Ross Wightman + + Remove the _assert function in forward function to be compatible with multi-resolution images. + """ + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + norm_layer=None, + flatten=True, + bias=True, + ): + super().__init__() + if isinstance(img_size, int): + img_size = to_2tuple(img_size) + elif isinstance(img_size, (tuple, list)) and len(img_size) == 2: + img_size = tuple(img_size) + else: + raise ValueError(f"img_size must be int or tuple/list of length 2. Got {img_size}") + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def update_image_size(self, img_size): + self.img_size = img_size + self.grid_size = (img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + + def forward(self, x): + # B, C, H, W = x.shape + # _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") + # _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x + + +def timestep_embedding(t, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. 
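+    :param repeat_only: if True, simply repeat `t` across the embedding dimension instead of computing sinusoidal embeddings.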
+    :return: an (N, D) Tensor of positional embeddings.
+    """
+    # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
+    if not repeat_only:
+        half = dim // 2
+        freqs = torch.exp(
+            -math.log(max_period)
+            * torch.arange(start=0, end=half, dtype=torch.float32)
+            / half
+        ).to(device=t.device)  # size: [dim/2], an exponentially decaying curve
+        args = t[:, None].float() * freqs[None]
+        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+        if dim % 2:
+            embedding = torch.cat(
+                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
+            )
+    else:
+        embedding = repeat(t, "b -> b d", d=dim)
+    return embedding
+
+
+class TimestepEmbedder(nn.Module):
+    """
+    Embeds scalar timesteps into vector representations.
+    """
+    def __init__(self, hidden_size, frequency_embedding_size=256, out_size=None):
+        super().__init__()
+        if out_size is None:
+            out_size = hidden_size
+        self.mlp = nn.Sequential(
+            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
+            nn.SiLU(),
+            nn.Linear(hidden_size, out_size, bias=True),
+        )
+        self.frequency_embedding_size = frequency_embedding_size
+
+    def forward(self, t):
+        t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype)
+        t_emb = self.mlp(t_freq)
+        return t_emb
diff --git a/HunYuanDiT/models/models.py b/HunYuanDiT/models/models.py
new file mode 100644
index 0000000..df1bbcb
--- /dev/null
+++ b/HunYuanDiT/models/models.py
@@ -0,0 +1,454 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from timm.models.vision_transformer import Mlp
+
+from .attn_layers import Attention, FlashCrossMHAModified, FlashSelfMHAModified, CrossAttention
+from .embedders import TimestepEmbedder, PatchEmbed, timestep_embedding
+from .norm_layers import RMSNorm
+from .poolers import AttentionPool
+from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop
+
+def modulate(x, shift, scale):
+    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+
+class FP32_Layernorm(nn.LayerNorm):
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        origin_dtype = inputs.dtype
+        return F.layer_norm(inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(),
+                            self.eps).to(origin_dtype)
+
+
+class FP32_SiLU(nn.SiLU):
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        return torch.nn.functional.silu(inputs.float(), inplace=False).to(inputs.dtype)
+
+
+class HunYuanDiTBlock(nn.Module):
+    """
+    A HunYuanDiT block with `add` conditioning.
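+    Block structure: AdaLN-style shift (from the conditioning vector) -> self-attention ->
+    cross-attention over the concatenated CLIP/T5 text states -> MLP, with an optional
+    U-Net-style long skip connection in the second half of the network.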
+ """ + def __init__(self, + hidden_size, + c_emb_size, + num_heads, + mlp_ratio=4.0, + text_states_dim=1024, + use_flash_attn=False, + qk_norm=False, + norm_type="layer", + skip=False, + ): + super().__init__() + self.use_flash_attn = use_flash_attn + use_ele_affine = True + + if norm_type == "layer": + norm_layer = FP32_Layernorm + elif norm_type == "rms": + norm_layer = RMSNorm + else: + raise ValueError(f"Unknown norm_type: {norm_type}") + + # ========================= Self-Attention ========================= + self.norm1 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6) + if use_flash_attn: + self.attn1 = FlashSelfMHAModified(hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=qk_norm) + else: + self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=qk_norm) + + # ========================= FFN ========================= + self.norm2 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + + # ========================= Add ========================= + # Simply use add like SDXL. + self.default_modulation = nn.Sequential( + FP32_SiLU(), + nn.Linear(c_emb_size, hidden_size, bias=True) + ) + + # ========================= Cross-Attention ========================= + if use_flash_attn: + self.attn2 = FlashCrossMHAModified(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=True, + qk_norm=qk_norm) + else: + self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=True, + qk_norm=qk_norm) + self.norm3 = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6) + + # ========================= Skip Connection ========================= + if skip: + self.skip_norm = norm_layer(2 * hidden_size, elementwise_affine=True, eps=1e-6) + self.skip_linear = nn.Linear(2 * hidden_size, hidden_size) + else: + self.skip_linear = None + + def forward(self, x, c=None, text_states=None, freq_cis_img=None, skip=None): + # Long Skip Connection + if self.skip_linear is not None: + cat = torch.cat([x, skip], dim=-1) + cat = self.skip_norm(cat) + x = self.skip_linear(cat) + + # Self-Attention + shift_msa = self.default_modulation(c).unsqueeze(dim=1) + attn_inputs = ( + self.norm1(x) + shift_msa, freq_cis_img, + ) + x = x + self.attn1(*attn_inputs)[0] + + # Cross-Attention + cross_inputs = ( + self.norm3(x), text_states, freq_cis_img + ) + x = x + self.attn2(*cross_inputs)[0] + + # FFN Layer + mlp_inputs = self.norm2(x) + x = x + self.mlp(mlp_inputs) + + return x + + +class FinalLayer(nn.Module): + """ + The final layer of HunYuanDiT. + """ + def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + FP32_SiLU(), + nn.Linear(c_emb_size, 2 * final_hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class HunYuanDiT(nn.Module): + """ + HunYuanDiT: Diffusion model with a Transformer backbone. + + Parameters + ---------- + args: argparse.Namespace + The arguments parsed by argparse. 
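+        Expected fields: infer_mode, norm, learn_sigma, text_states_dim, text_states_dim_t5, text_len, text_len_t5 (see conf.py).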
+ input_size: tuple + The size of the input image. + patch_size: int + The size of the patch. + in_channels: int + The number of input channels. + hidden_size: int + The hidden size of the transformer backbone. + depth: int + The number of transformer blocks. + num_heads: int + The number of attention heads. + mlp_ratio: float + The ratio of the hidden size of the MLP in the transformer block. + log_fn: callable + The logging function. + """ + def __init__( + self, args, + input_size=(32, 32), + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + log_fn=print, + **kwargs, + ): + super().__init__() + self.args = args + self.log_fn = log_fn + self.depth = depth + self.learn_sigma = args.learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if args.learn_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + self.hidden_size = hidden_size + self.head_size = hidden_size // num_heads + self.text_states_dim = args.text_states_dim + self.text_states_dim_t5 = args.text_states_dim_t5 + self.text_len = args.text_len + self.text_len_t5 = args.text_len_t5 + self.norm = args.norm + + use_flash_attn = args.infer_mode == 'fa' + if use_flash_attn: + log_fn(f" Enable Flash Attention.") + qk_norm = True # See http://arxiv.org/abs/2302.05442 for details. + + self.mlp_t5 = nn.Sequential( + nn.Linear(self.text_states_dim_t5, self.text_states_dim_t5 * 4, bias=True), + FP32_SiLU(), + nn.Linear(self.text_states_dim_t5 * 4, self.text_states_dim, bias=True), + ) + # learnable replace + self.text_embedding_padding = nn.Parameter( + torch.randn(self.text_len + self.text_len_t5, self.text_states_dim, dtype=torch.float32)) + + # Attention pooling + self.pooler = AttentionPool(self.text_len_t5, self.text_states_dim_t5, num_heads=8, output_dim=1024) + + # Here we use a default learned embedder layer for future extension. + self.style_embedder = nn.Embedding(1, hidden_size) + + # Image size and crop size conditions + self.extra_in_dim = 256 * 6 + hidden_size + + # Text embedding for `add` + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size) + self.t_embedder = TimestepEmbedder(hidden_size) + self.extra_in_dim += 1024 + self.extra_embedder = nn.Sequential( + nn.Linear(self.extra_in_dim, hidden_size * 4), + FP32_SiLU(), + nn.Linear(hidden_size * 4, hidden_size, bias=True), + ) + + # Image embedding + num_patches = self.x_embedder.num_patches + log_fn(f" Number of tokens: {num_patches}") + + # HUnYuanDiT Blocks + self.blocks = nn.ModuleList([ + HunYuanDiTBlock(hidden_size=hidden_size, + c_emb_size=hidden_size, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + text_states_dim=self.text_states_dim, + use_flash_attn=use_flash_attn, + qk_norm=qk_norm, + norm_type=self.norm, + skip=layer > depth // 2, + ) + for layer in range(depth) + ]) + + self.final_layer = FinalLayer(hidden_size, hidden_size, patch_size, self.out_channels) + self.unpatchify_channels = self.out_channels + + # probably not needed when not training? + # self.initialize_weights() + + def forward_raw(self, + x, + t, + encoder_hidden_states=None, + text_embedding_mask=None, + encoder_hidden_states_t5=None, + text_embedding_mask_t5=None, + image_meta_size=None, + style=None, + cos_cis_img=None, + sin_cis_img=None, + return_dict=False, + ): + """ + Forward pass of the encoder. 
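+        Called from `forward` below, which adapts the ComfyUI-style inputs into these arguments.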
+ + Parameters + ---------- + x: torch.Tensor + (B, D, H, W) + t: torch.Tensor + (B) + encoder_hidden_states: torch.Tensor + CLIP text embedding, (B, L_clip, D) + text_embedding_mask: torch.Tensor + CLIP text embedding mask, (B, L_clip) + encoder_hidden_states_t5: torch.Tensor + T5 text embedding, (B, L_t5, D) + text_embedding_mask_t5: torch.Tensor + T5 text embedding mask, (B, L_t5) + image_meta_size: torch.Tensor + (B, 6) + style: torch.Tensor + (B) + cos_cis_img: torch.Tensor + sin_cis_img: torch.Tensor + return_dict: bool + Whether to return a dictionary. + """ + + text_states = encoder_hidden_states # 2,77,1024 + text_states_t5 = encoder_hidden_states_t5 # 2,256,2048 + text_states_mask = text_embedding_mask.bool() # 2,77 + text_states_t5_mask = text_embedding_mask_t5.bool() # 2,256 + b_t5, l_t5, c_t5 = text_states_t5.shape + text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)) + text_states = torch.cat([text_states, text_states_t5.view(b_t5, l_t5, -1)], dim=1) # 2,205,1024 + clip_t5_mask = torch.cat([text_states_mask, text_states_t5_mask], dim=-1) + + clip_t5_mask = clip_t5_mask + text_states = torch.where(clip_t5_mask.unsqueeze(2), text_states, self.text_embedding_padding.to(text_states)) + + _, _, oh, ow = x.shape + th, tw = oh // self.patch_size, ow // self.patch_size + + # ========================= Build time and image embedding ========================= + t = self.t_embedder(t) + x = self.x_embedder(x) + + # Get image RoPE embedding according to `reso`lution. + freqs_cis_img = (cos_cis_img, sin_cis_img) + + # ========================= Concatenate all extra vectors ========================= + # Build text tokens with pooling + extra_vec = self.pooler(encoder_hidden_states_t5) + + # Build image meta size tokens + image_meta_size = timestep_embedding(image_meta_size.view(-1), 256) # [B * 6, 256] + # if self.args.use_fp16: + # image_meta_size = image_meta_size.half() + image_meta_size = image_meta_size.view(-1, 6 * 256) + extra_vec = torch.cat([extra_vec, image_meta_size], dim=1) # [B, D + 6 * 256] + + # Build style tokens + style_embedding = self.style_embedder(style) + extra_vec = torch.cat([extra_vec, style_embedding], dim=1) + + # Concatenate all extra vectors + c = t + self.extra_embedder(extra_vec.to(self.dtype)) # [B, D] + + # ========================= Forward pass through HunYuanDiT blocks ========================= + skips = [] + for layer, block in enumerate(self.blocks): + if layer > self.depth // 2: + skip = skips.pop() + x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D) + else: + x = block(x, c, text_states, freqs_cis_img) # (N, L, D) + + if layer < (self.depth // 2 - 1): + skips.append(x) + + # ========================= Final layer ========================= + x = self.final_layer(x, c) # (N, L, patch_size ** 2 * out_channels) + x = self.unpatchify(x, th, tw) # (N, out_channels, H, W) + + if return_dict: + return {'x': x} + return x + + def calc_rope(self, height, width): + """ + Probably not the best in terms of perf to have this here + """ + th = height // 8 // self.patch_size + tw = width // 8 // self.patch_size + base_size = 512 // 8 // self.patch_size + start, stop = get_fill_resize_and_crop((th, tw), base_size) + sub_args = [start, stop, (th, tw)] + rope = get_2d_rotary_pos_embed(self.head_size, *sub_args) + return rope + + def forward(self, x, timesteps, context, context_mask=None, context_t5=None, context_t5_mask=None, image_meta_size=None, **kwargs): + """ + Forward pass that adapts comfy input to original forward function + x: (N, C, H, W) 
tensor of spatial inputs (images or latent representations of images) + timesteps: (N,) tensor of diffusion timesteps + context: (N, 1, 77, C) CLIP conditioning + context_t5: (N, 1, 256, C) MT5 conditioning + """ + # context_mask = torch.zeros(x.shape[0], 77, device=x.device) + # context_t5_mask = torch.zeros(x.shape[0], 256, device=x.device) + + # style + style = torch.as_tensor([0, 0] * (x.shape[0]//2), device=x.device) + + # image size - todo: from cond + width = x.shape[3] + height = x.shape[2] + src_size_cond = (width//2*16, height//2*16) + size_cond = list(src_size_cond) + [width*8, height*8, 0, 0] + image_meta_size = torch.as_tensor([size_cond] * x.shape[0], device=x.device) + + # RoPE + rope = self.calc_rope(*src_size_cond) + + # Run original forward pass + out = self.forward_raw( + x = x.to(self.dtype), + t = timesteps.to(self.dtype), + encoder_hidden_states = context.to(self.dtype), + text_embedding_mask = context_mask.to(self.dtype), + encoder_hidden_states_t5 = context_t5.to(self.dtype), + text_embedding_mask_t5 = context_t5_mask.to(self.dtype), + image_meta_size = image_meta_size.to(self.dtype), + style = style, + cos_cis_img = rope[0], + sin_cis_img = rope[1], + ) + + # return + out = out.to(torch.float) + if self.learn_sigma: + eps, rest = out[:, :self.in_channels], out[:, self.in_channels:] + return eps + else: + return out + + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.extra_embedder[0].weight, std=0.02) + nn.init.normal_(self.extra_embedder[2].weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in HunYuanDiT blocks: + for block in self.blocks: + nn.init.constant_(block.default_modulation[-1].weight, 0) + nn.init.constant_(block.default_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def unpatchify(self, x, h, w): + """ + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + """ + c = self.unpatchify_channels + p = self.x_embedder.patch_size[0] + # h = w = int(x.shape[1] ** 0.5) + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p)) + return imgs diff --git a/HunYuanDiT/models/norm_layers.py b/HunYuanDiT/models/norm_layers.py new file mode 100644 index 0000000..5204ad9 --- /dev/null +++ b/HunYuanDiT/models/norm_layers.py @@ -0,0 +1,68 @@ +import torch +import torch.nn as nn + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. 
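+            elementwise_affine (bool, optional): If True, adds a learnable per-dimension scale `weight`. Default is True.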
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + super().__init__() + self.eps = eps + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + + """ + output = self._norm(x.float()).type_as(x) + if hasattr(self, "weight"): + output = output * self.weight + return output + + +class GroupNorm32(nn.GroupNorm): + def __init__(self, num_groups, num_channels, eps=1e-5, dtype=None): + super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps, dtype=dtype) + + def forward(self, x): + y = super().forward(x).to(x.dtype) + return y + +def normalization(channels, dtype=None): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(num_channels=channels, num_groups=32, dtype=dtype) diff --git a/HunYuanDiT/models/poolers.py b/HunYuanDiT/models/poolers.py new file mode 100644 index 0000000..a4adcac --- /dev/null +++ b/HunYuanDiT/models/poolers.py @@ -0,0 +1,39 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class AttentionPool(nn.Module): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5) + self.k_proj = nn.Linear(embed_dim, embed_dim) + self.q_proj = nn.Linear(embed_dim, embed_dim) + self.v_proj = nn.Linear(embed_dim, embed_dim) + self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) + self.num_heads = num_heads + + def forward(self, x): + x = x.permute(1, 0, 2) # NLC -> LNC + x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (L+1)NC + x = x + self.positional_embedding[:, None, :].to(x.dtype) # (L+1)NC + x, _ = F.multi_head_attention_forward( + query=x[:1], key=x, value=x, + embed_dim_to_check=x.shape[-1], + num_heads=self.num_heads, + q_proj_weight=self.q_proj.weight, + k_proj_weight=self.k_proj.weight, + v_proj_weight=self.v_proj.weight, + in_proj_weight=None, + in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False + ) + return x.squeeze(0) diff --git a/HunYuanDiT/models/posemb_layers.py b/HunYuanDiT/models/posemb_layers.py new file mode 100644 index 0000000..62c83df --- /dev/null +++ b/HunYuanDiT/models/posemb_layers.py @@ -0,0 +1,225 @@ +import torch +import numpy as np +from typing import Union + + +def _to_tuple(x): + if isinstance(x, int): + return x, x + else: + return x + + +def get_fill_resize_and_crop(src, tgt): # src 来源的分辨率 tgt base 分辨率 + th, tw = _to_tuple(tgt) + h, w = _to_tuple(src) + + tr = th / tw # base 分辨率 + r = h / w # 目标分辨率 + + # resize + if 
diff --git a/HunYuanDiT/models/posemb_layers.py b/HunYuanDiT/models/posemb_layers.py
new file mode 100644
index 0000000..62c83df
--- /dev/null
+++ b/HunYuanDiT/models/posemb_layers.py
@@ -0,0 +1,225 @@
+import torch
+import numpy as np
+from typing import Union
+
+
+def _to_tuple(x):
+    if isinstance(x, int):
+        return x, x
+    else:
+        return x
+
+
+def get_fill_resize_and_crop(src, tgt):  # src: source resolution, tgt: base resolution
+    th, tw = _to_tuple(tgt)
+    h, w = _to_tuple(src)
+
+    tr = th / tw  # aspect ratio of the base (target) resolution
+    r = h / w  # aspect ratio of the source resolution
+
+    # resize
+    if r > tr:
+        resize_height = th
+        resize_width = int(round(th / h * w))
+    else:
+        resize_width = tw
+        resize_height = int(round(tw / w * h))  # resize the source so it fits inside the base resolution
+
+    crop_top = int(round((th - resize_height) / 2.0))
+    crop_left = int(round((tw - resize_width) / 2.0))
+
+    return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
+
+
+def get_meshgrid(start, *args):
+    if len(args) == 0:
+        # start is grid_size
+        num = _to_tuple(start)
+        start = (0, 0)
+        stop = num
+    elif len(args) == 1:
+        # start is start, args[0] is stop, step is 1
+        start = _to_tuple(start)
+        stop = _to_tuple(args[0])
+        num = (stop[0] - start[0], stop[1] - start[1])
+    elif len(args) == 2:
+        # start is start, args[0] is stop, args[1] is num
+        start = _to_tuple(start)   # top-left corner, e.g. (12, 0)
+        stop = _to_tuple(args[0])  # bottom-right corner, e.g. (20, 32)
+        num = _to_tuple(args[1])   # target size, e.g. (32, 124)
+    else:
+        raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
+
+    grid_h = np.linspace(start[0], stop[0], num[0], endpoint=False, dtype=np.float32)  # e.g. 32 samples between 12 and 20, 124 samples between 0 and 32
+    grid_w = np.linspace(start[1], stop[1], num[1], endpoint=False, dtype=np.float32)
+    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = np.stack(grid, axis=0)  # [2, W, H]
+    return grid
+
+#################################################################################
+#                Sine/Cosine Positional Embedding Functions                     #
+#################################################################################
+# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
+
+def get_2d_sincos_pos_embed(embed_dim, start, *args, cls_token=False, extra_tokens=0):
+    """
+    grid_size: int of the grid height and width
+    return:
+    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+    """
+    grid = get_meshgrid(start, *args)  # [2, H, w]
+    # grid_h = np.arange(grid_size, dtype=np.float32)
+    # grid_w = np.arange(grid_size, dtype=np.float32)
+    # grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    # grid = np.stack(grid, axis=0)  # [2, W, H]
+
+    grid = grid.reshape([2, 1, *grid.shape[1:]])
+    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+    if cls_token and extra_tokens > 0:
+        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+    return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+    assert embed_dim % 2 == 0
+
+    # use half of dimensions to encode grid_h
+    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+    return emb
+
+
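Note: a minimal sanity check of the sine/cosine path (grid size and embedding dim are arbitrary): a 16x16 grid with a 64-dim embedding yields one row per grid position, with height and width each encoded in half of the channels.

    from HunYuanDiT.models.posemb_layers import get_2d_sincos_pos_embed

    pos = get_2d_sincos_pos_embed(64, 16)   # start=16 -> 16x16 grid, no cls token
    print(pos.shape)                        # (256, 64): H*W rows, embed_dim columns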
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+    """
+    embed_dim: output dimension for each position
+    pos: a list of positions to be encoded: size (W,H)
+    out: (M, D)
+    """
+    assert embed_dim % 2 == 0
+    omega = np.arange(embed_dim // 2, dtype=np.float64)
+    omega /= embed_dim / 2.
+    omega = 1. / 10000**omega  # (D/2,)
+
+    pos = pos.reshape(-1)  # (M,)
+    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
+
+    emb_sin = np.sin(out)  # (M, D/2)
+    emb_cos = np.cos(out)  # (M, D/2)
+
+    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+    return emb
+
+
+#################################################################################
+#                  Rotary Positional Embedding Functions                        #
+#################################################################################
+# https://github.com/facebookresearch/llama/blob/main/llama/model.py#L443
+
+def get_2d_rotary_pos_embed(embed_dim, start, *args, use_real=True):
+    """
+    This is a 2d version of precompute_freqs_cis, which is a RoPE for image tokens with 2d structure.
+
+    Parameters
+    ----------
+    embed_dim: int
+        embedding dimension size
+    start: int or tuple of int
+        If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, step is 1;
+        If len(args) == 2, start is start, args[0] is stop, args[1] is num.
+    use_real: bool
+        If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+
+    Returns
+    -------
+    pos_embed: torch.Tensor
+        [HW, D/2]
+    """
+    grid = get_meshgrid(start, *args)  # [2, H, w]
+    grid = grid.reshape([2, 1, *grid.shape[1:]])  # returns a sampling grid with the same resolution as the target resolution
+    pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
+    return pos_embed
+
+
+def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
+    assert embed_dim % 4 == 0
+
+    # use half of dimensions to encode grid_h
+    emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real)  # (H*W, D/4)
+    emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real)  # (H*W, D/4)
+
+    if use_real:
+        cos = torch.cat([emb_h[0], emb_w[0]], dim=1)  # (H*W, D/2)
+        sin = torch.cat([emb_h[1], emb_w[1]], dim=1)  # (H*W, D/2)
+        return cos, sin
+    else:
+        emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D/2)
+        return emb
+
+
+def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False):
+    """
+    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
+
+    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
+    and the end index 'end'. The 'theta' parameter scales the frequencies.
+    The returned tensor contains complex values in complex64 data type.
+
+    Args:
+        dim (int): Dimension of the frequency tensor.
+        pos (np.ndarray, int): Position indices for the frequency tensor. [S] or scalar
+        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
+        use_real (bool, optional): If True, return real part and imaginary part separately.
+                                   Otherwise, return complex numbers.
+
+    Returns:
+        torch.Tensor: Precomputed frequency tensor with complex exponentials. [S, D/2]
+
+    """
+    if isinstance(pos, int):
+        pos = np.arange(pos)
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))  # [D/2]
+    t = torch.from_numpy(pos).to(freqs.device)  # type: ignore  # [S]
+    freqs = torch.outer(t, freqs).float()  # type: ignore  # [S, D/2]
+    if use_real:
+        freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
+        freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
+        return freqs_cos, freqs_sin
+    else:
+        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
+        return freqs_cis
+
+
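Note: the two return modes of get_1d_rotary_pos_embed carry the same rotations in different layouts; a small sketch with arbitrary sizes:

    import torch
    from HunYuanDiT.models.posemb_layers import get_1d_rotary_pos_embed

    freqs_cis = get_1d_rotary_pos_embed(32, 8)                 # complex64, [8, 16]
    cos, sin = get_1d_rotary_pos_embed(32, 8, use_real=True)   # each [8, 32], channels duplicated
    assert torch.allclose(freqs_cis.real, cos[:, ::2])         # same cosines, different layout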
""" + if rope_img == 'extend': + # 拓展模式 + sub_args = [(th, tw)] + elif rope_img.startswith('base'): + # 基于一个尺寸, 其他尺寸插值获得. + base_size = int(rope_img[4:]) // 8 // patch_size # 基于512作为base,其他根据512差值得到 + start, stop = get_fill_resize_and_crop((th, tw), base_size) # 需要在32x32里面 crop的左上角和右下角 + sub_args = [start, stop, (th, tw)] + else: + raise ValueError(f"Unknown rope_img: {rope_img}") + return sub_args + + +def init_image_posemb(rope_img, + resolutions, + patch_size, + hidden_size, + num_heads, + log_fn, + rope_real=True, + ): + freqs_cis_img = {} + for reso in resolutions: + th, tw = reso.height // 8 // patch_size, reso.width // 8 // patch_size + sub_args = calc_sizes(rope_img, patch_size, th, tw) # [左上角, 右下角, 目标高宽] 需要在32x32里面 crop的左上角和右下角 + freqs_cis_img[str(reso)] = get_2d_rotary_pos_embed(hidden_size // num_heads, *sub_args, use_real=rope_real) + log_fn(f" Using image RoPE ({rope_img}) ({'real' if rope_real else 'complex'}): {sub_args} | ({reso}) " + f"{freqs_cis_img[str(reso)][0].shape if rope_real else freqs_cis_img[str(reso)].shape}") + return freqs_cis_img diff --git a/HunYuanDiT/nodes.py b/HunYuanDiT/nodes.py new file mode 100644 index 0000000..3b17c29 --- /dev/null +++ b/HunYuanDiT/nodes.py @@ -0,0 +1,169 @@ +import os +import folder_paths + +from .conf import hydit_conf +from .loader import load_hydit + +class HYDiTCheckpointLoader: + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "ckpt_name": (folder_paths.get_filename_list("checkpoints"),), + "model": (list(hydit_conf.keys()),{"default":"G/2"}), + } + } + RETURN_TYPES = ("MODEL",) + RETURN_NAMES = ("model",) + FUNCTION = "load_checkpoint" + CATEGORY = "ExtraModels/HunyuanDiT" + TITLE = "Hunyuan DiT Checkpoint Loader" + + def load_checkpoint(self, ckpt_name, model): + ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) + model_conf = hydit_conf[model] + model = load_hydit( + model_path = ckpt_path, + model_conf = model_conf, + ) + return (model,) + +#### temp stuff for the text encoder #### +import torch +from .tenc import load_clip, load_t5 +from ..utils.dtype import string_to_dtype +dtypes = [ + "default", + "auto (comfy)", + "FP32", + "FP16", + "BF16" +] + +class HYDiTTextEncoderLoader: + @classmethod + def INPUT_TYPES(s): + devices = ["auto", "cpu", "gpu"] + # hack for using second GPU as offload + for k in range(1, torch.cuda.device_count()): + devices.append(f"cuda:{k}") + return { + "required": { + "clip_name": (folder_paths.get_filename_list("clip"),), + "mt5_name": (folder_paths.get_filename_list("t5"),), + "device": (devices, {"default":"cpu"}), + "dtype": (dtypes,), + } + } + + RETURN_TYPES = ("CLIP", "T5") + FUNCTION = "load_model" + CATEGORY = "ExtraModels/HunyuanDiT" + TITLE = "Hunyuan DiT Text Encoder Loader" + + def load_model(self, clip_name, mt5_name, device, dtype): + dtype = string_to_dtype(dtype, "text_encoder") + if device == "cpu": + assert dtype in [None, torch.float32, torch.bfloat16], f"Can't use dtype '{dtype}' with CPU! Set dtype to 'default' or 'bf16'." 
diff --git a/HunYuanDiT/nodes.py b/HunYuanDiT/nodes.py
new file mode 100644
index 0000000..3b17c29
--- /dev/null
+++ b/HunYuanDiT/nodes.py
@@ -0,0 +1,169 @@
+import os
+import folder_paths
+
+from .conf import hydit_conf
+from .loader import load_hydit
+
+class HYDiTCheckpointLoader:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "ckpt_name": (folder_paths.get_filename_list("checkpoints"),),
+                "model": (list(hydit_conf.keys()), {"default": "G/2"}),
+            }
+        }
+    RETURN_TYPES = ("MODEL",)
+    RETURN_NAMES = ("model",)
+    FUNCTION = "load_checkpoint"
+    CATEGORY = "ExtraModels/HunyuanDiT"
+    TITLE = "Hunyuan DiT Checkpoint Loader"
+
+    def load_checkpoint(self, ckpt_name, model):
+        ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
+        model_conf = hydit_conf[model]
+        model = load_hydit(
+            model_path = ckpt_path,
+            model_conf = model_conf,
+        )
+        return (model,)
+
+#### temp stuff for the text encoder ####
+import torch
+from .tenc import load_clip, load_t5
+from ..utils.dtype import string_to_dtype
+dtypes = [
+    "default",
+    "auto (comfy)",
+    "FP32",
+    "FP16",
+    "BF16"
+]
+
+class HYDiTTextEncoderLoader:
+    @classmethod
+    def INPUT_TYPES(s):
+        devices = ["auto", "cpu", "gpu"]
+        # hack for using second GPU as offload
+        for k in range(1, torch.cuda.device_count()):
+            devices.append(f"cuda:{k}")
+        return {
+            "required": {
+                "clip_name": (folder_paths.get_filename_list("clip"),),
+                "mt5_name": (folder_paths.get_filename_list("t5"),),
+                "device": (devices, {"default": "cpu"}),
+                "dtype": (dtypes,),
+            }
+        }
+
+    RETURN_TYPES = ("CLIP", "T5")
+    FUNCTION = "load_model"
+    CATEGORY = "ExtraModels/HunyuanDiT"
+    TITLE = "Hunyuan DiT Text Encoder Loader"
+
+    def load_model(self, clip_name, mt5_name, device, dtype):
+        dtype = string_to_dtype(dtype, "text_encoder")
+        if device == "cpu":
+            assert dtype in [None, torch.float32, torch.bfloat16], f"Can't use dtype '{dtype}' with CPU! Set dtype to 'default' or 'bf16'."
+
+        clip = load_clip(
+            model_path = folder_paths.get_full_path("clip", clip_name),
+            device = device,
+            dtype = dtype,
+        )
+        t5 = load_t5(
+            model_path = folder_paths.get_full_path("t5", mt5_name),
+            device = device,
+            dtype = dtype,
+        )
+        return (clip, t5)
+
+class HYDiTTextEncode:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "text": ("STRING", {"multiline": True}),
+                "text_t5": ("STRING", {"multiline": True}),
+                "CLIP": ("CLIP",),
+                "T5": ("T5",),
+            }
+        }
+
+    RETURN_TYPES = ("CONDITIONING",)
+    FUNCTION = "encode"
+    CATEGORY = "ExtraModels/HunyuanDiT"
+    TITLE = "Hunyuan DiT Text Encode"
+
+    def encode(self, text, text_t5, CLIP, T5):
+        # T5
+        t5_pre = T5.tokenizer(
+            text,
+            max_length = T5.cond_stage_model.max_length,
+            padding = 'max_length',
+            truncation = True,
+            return_attention_mask = True,
+            add_special_tokens = True,
+            return_tensors = 'pt'
+        )
+        t5_mask = t5_pre["attention_mask"]
+        with torch.no_grad():
+            t5_outs = T5.cond_stage_model.transformer(
+                input_ids = t5_pre["input_ids"].to(T5.load_device),
+                attention_mask = t5_mask.to(T5.load_device),
+                output_hidden_states = True,
+            )
+        # to-do: replace -1 for clip skip
+        t5_embs = t5_outs["hidden_states"][-1].float().cpu()
+
+        # "clip"
+        clip_pre = CLIP.tokenizer(
+            text,
+            max_length = CLIP.cond_stage_model.max_length,
+            padding = 'max_length',
+            truncation = True,
+            return_attention_mask = True,
+            add_special_tokens = True,
+            return_tensors = 'pt'
+        )
+        clip_mask = clip_pre["attention_mask"]
+        with torch.no_grad():
+            clip_outs = CLIP.cond_stage_model.transformer(
+                input_ids = clip_pre["input_ids"].to(CLIP.load_device),
+                attention_mask = clip_mask.to(CLIP.load_device),
+            )
+        # to-do: add hidden states
+        clip_embs = clip_outs[0].float().cpu()
+
+        # combined cond
+        return ([[
+            clip_embs, {
+                "context_t5": t5_embs,
+                "context_mask": clip_mask.float(),
+                "context_t5_mask": t5_mask.float()
+            }
+        ]],)
+
+class HYDiTTextEncodeSimple(HYDiTTextEncode):
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "text": ("STRING", {"multiline": True}),
+                "CLIP": ("CLIP",),
+                "T5": ("T5",),
+            }
+        }
+
+    FUNCTION = "encode_simple"
+    TITLE = "Hunyuan DiT Text Encode (simple)"
+
+    def encode_simple(self, text, **args):
+        return self.encode(text=text, text_t5=text, **args)
+
+NODE_CLASS_MAPPINGS = {
+    "HYDiTCheckpointLoader": HYDiTCheckpointLoader,
+    "HYDiTTextEncoderLoader": HYDiTTextEncoderLoader,
+    "HYDiTTextEncode": HYDiTTextEncode,
+    "HYDiTTextEncodeSimple": HYDiTTextEncodeSimple,
+}
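Note: the CONDITIONING entry packs the BERT/CLIP embeddings as the main tensor and stashes the mT5 stream plus both attention masks in the extra dict. A sketch of consuming it inside a running ComfyUI session, assuming text, text_t5 and the CLIP/T5 wrappers from the loader node are already in scope (key names as defined in encode above):

    (cond,) = HYDiTTextEncode().encode(text, text_t5, CLIP, T5)
    clip_embs, extra = cond[0]
    t5_embs = extra["context_t5"]          # [1, 256, t5_hidden]
    clip_mask = extra["context_mask"]      # [1, 77]
    t5_mask = extra["context_t5_mask"]     # [1, 256]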
diff --git a/HunYuanDiT/tenc.py b/HunYuanDiT/tenc.py
new file mode 100644
index 0000000..7883982
--- /dev/null
+++ b/HunYuanDiT/tenc.py
@@ -0,0 +1,169 @@
+# This is for loading the CLIP-style (BERT) + mT5 text encoders for HunYuanDiT
+import os
+import torch
+from transformers import AutoTokenizer, modeling_utils
+from transformers import T5Config, T5EncoderModel, BertConfig, BertModel
+
+import comfy.model_patcher
+import comfy.model_management as model_management
+import comfy.utils
+
+class mT5Model(torch.nn.Module):
+    def __init__(self, textmodel_json_config=None, device="cpu", max_length=256, freeze=True, dtype=None):
+        super().__init__()
+        self.device = device
+        self.dtype = dtype
+        self.max_length = max_length
+        if textmodel_json_config is None:
+            textmodel_json_config = os.path.join(
+                os.path.dirname(os.path.realpath(__file__)),
+                "config_mt5.json"
+            )
+        config = T5Config.from_json_file(textmodel_json_config)
+        with modeling_utils.no_init_weights():
+            self.transformer = T5EncoderModel(config)
+        self.to(dtype)
+        if freeze:
+            self.freeze()
+
+    def freeze(self):
+        self.transformer = self.transformer.eval()
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def load_sd(self, sd):
+        return self.transformer.load_state_dict(sd, strict=False)
+
+    def to(self, *args, **kwargs):
+        self.transformer.to(*args, **kwargs)
+
+class hyCLIPModel(torch.nn.Module):
+    def __init__(self, textmodel_json_config=None, device="cpu", max_length=77, freeze=True, dtype=None):
+        super().__init__()
+        self.device = device
+        self.dtype = dtype
+        self.max_length = max_length
+        if textmodel_json_config is None:
+            textmodel_json_config = os.path.join(
+                os.path.dirname(os.path.realpath(__file__)),
+                "config_clip.json"
+            )
+        config = BertConfig.from_json_file(textmodel_json_config)
+        with modeling_utils.no_init_weights():
+            self.transformer = BertModel(config)
+        self.to(dtype)
+        if freeze:
+            self.freeze()
+
+    def freeze(self):
+        self.transformer = self.transformer.eval()
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def load_sd(self, sd):
+        return self.transformer.load_state_dict(sd, strict=False)
+
+    def to(self, *args, **kwargs):
+        self.transformer.to(*args, **kwargs)
+
+class EXM_HyDiT_Tenc_Temp:
+    def __init__(self, no_init=False, device="cpu", dtype=None, model_class="mT5", **kwargs):
+        if no_init:
+            return
+
+        if device == "auto":
+            size = 0
+            self.load_device = model_management.text_encoder_device()
+            self.offload_device = model_management.text_encoder_offload_device()
+            self.init_device = "cpu"
+        elif device == "cpu":
+            size = 0
+            self.load_device = "cpu"
+            self.offload_device = "cpu"
+            self.init_device = "cpu"
+        elif device.startswith("cuda"):
+            print("Direct CUDA device override!\nVRAM will not be freed by default.")
+            size = 0
+            self.load_device = device
+            self.offload_device = device
+            self.init_device = device
+        else:
+            size = 0
+            self.load_device = model_management.get_torch_device()
+            self.offload_device = "cpu"
+            self.init_device = "cpu"
+
+        self.dtype = dtype
+        self.device = device
+        if model_class == "mT5":
+            self.cond_stage_model = mT5Model(
+                device = device,
+                dtype = dtype,
+            )
+            tokenizer_args = {"subfolder": "t2i/mt5"}
+        else:
+            self.cond_stage_model = hyCLIPModel(
+                device = device,
+                dtype = dtype,
+            )
+            tokenizer_args = {"subfolder": "t2i/tokenizer"}
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            "Tencent-Hunyuan/HunyuanDiT",
+            **tokenizer_args
+        )
+        self.patcher = comfy.model_patcher.ModelPatcher(
+            self.cond_stage_model,
+            load_device = self.load_device,
+            offload_device = self.offload_device,
+            current_device = self.load_device,
+            size = size,
+        )
+
+    def clone(self):
+        n = EXM_HyDiT_Tenc_Temp(no_init=True)
+        n.patcher = self.patcher.clone()
+        n.cond_stage_model = self.cond_stage_model
+        n.tokenizer = self.tokenizer
+        return n
+
+    def load_sd(self, sd):
+        return self.cond_stage_model.load_sd(sd)
+
+    def get_sd(self):
+        return self.cond_stage_model.state_dict()
+
+    def load_model(self):
+        if self.load_device != "cpu":
+            model_management.load_model_gpu(self.patcher)
+        return self.patcher
+
+    def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
+        return self.patcher.add_patches(patches, strength_patch, strength_model)
+
+    def get_key_patches(self):
+        return self.patcher.get_key_patches()
+
+def load_clip(model_path, **kwargs):
+    model = EXM_HyDiT_Tenc_Temp(model_class="clip", **kwargs)
+    sd = comfy.utils.load_torch_file(model_path)
+
+    prefix = "bert."
+    state_dict = {}
+    for key in sd:
+        nkey = key
+        if key.startswith(prefix):
+            nkey = key[len(prefix):]
+        state_dict[nkey] = sd[key]
+
+    m, e = model.load_sd(state_dict)
+    if len(m) > 0 or len(e) > 0:
+        print(f"HYDiT: clip missing {len(m)} keys ({len(e)} extra)")
+    return model
+
+def load_t5(model_path, **kwargs):
+    model = EXM_HyDiT_Tenc_Temp(model_class="mT5", **kwargs)
+    sd = comfy.utils.load_torch_file(model_path)
+    m, e = model.load_sd(sd)
+    if len(m) > 0 or len(e) > 0:
+        print(f"HYDiT: mT5 missing {len(m)} keys ({len(e)} extra)")
+    return model
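Note: load_clip and load_t5 are what the text-encoder loader node calls under the hood; standalone usage looks roughly like this inside a ComfyUI environment (the file names are placeholders for whatever the user has in the clip/ and t5/ model folders):

    import folder_paths
    from HunYuanDiT.tenc import load_clip, load_t5

    clip = load_clip(
        model_path=folder_paths.get_full_path("clip", "hydit_clip.bin"),  # placeholder file name
        device="cpu", dtype=None,
    )
    t5 = load_t5(
        model_path=folder_paths.get_full_path("t5", "mt5_xl.bin"),        # placeholder file name
        device="cpu", dtype=None,
    )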
diff --git a/__init__.py b/__init__.py
index 38967a2..24517dc 100644
--- a/__init__.py
+++ b/__init__.py
@@ -22,6 +22,10 @@
 from .T5.nodes import NODE_CLASS_MAPPINGS as T5_Nodes
 NODE_CLASS_MAPPINGS.update(T5_Nodes)
 
+# HYDiT
+from .HunYuanDiT.nodes import NODE_CLASS_MAPPINGS as HunYuanDiT_Nodes
+NODE_CLASS_MAPPINGS.update(HunYuanDiT_Nodes)
+
 # VAE
 from .VAE.nodes import NODE_CLASS_MAPPINGS as VAE_Nodes
 NODE_CLASS_MAPPINGS.update(VAE_Nodes)
diff --git a/requirements.txt b/requirements.txt
index 51c2340..bcd3529 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,5 @@
-timm==0.6.13
+timm>=0.6.13
 sentencepiece>=0.1.97
 transformers>=4.34.1
 accelerate>=0.23.0
+einops>=0.6.0
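Note: the __init__.py hook merges the four new node classes into the shared NODE_CLASS_MAPPINGS dict; a quick way to confirm what got registered from a running ComfyUI session (import path is illustrative):

    from HunYuanDiT.nodes import NODE_CLASS_MAPPINGS as HunYuanDiT_Nodes

    for name, cls in HunYuanDiT_Nodes.items():
        print(name, "->", getattr(cls, "TITLE", cls.__name__))
    # HYDiTCheckpointLoader -> Hunyuan DiT Checkpoint Loader
    # HYDiTTextEncoderLoader -> Hunyuan DiT Text Encoder Loader
    # HYDiTTextEncode -> Hunyuan DiT Text Encode
    # HYDiTTextEncodeSimple -> Hunyuan DiT Text Encode (simple)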