diff --git a/README.md b/README.md
index aac777a..f2b8cbf 100644
--- a/README.md
+++ b/README.md
@@ -76,13 +76,13 @@ To use our model, please follow the inference code in [inference.py](./inference
 
 #### For text-to-video generation:
 
 ```bash
-python inference.py --ckpt_dir 'PATH' --prompt "PROMPT" --height HEIGHT --width WIDTH --num_frames NUM_FRAMES --seed SEED
+python inference.py --ckpt_path 'PATH' --prompt "PROMPT" --height HEIGHT --width WIDTH --num_frames NUM_FRAMES --seed SEED
 ```
 
 #### For image-to-video generation:
 
 ```bash
-python inference.py --ckpt_dir 'PATH' --prompt "PROMPT" --input_image_path IMAGE_PATH --height HEIGHT --width WIDTH --num_frames NUM_FRAMES --seed SEED
+python inference.py --ckpt_path 'PATH' --prompt "PROMPT" --input_image_path IMAGE_PATH --height HEIGHT --width WIDTH --num_frames NUM_FRAMES --seed SEED
 ```
 
 ## ComfyUI Integration
diff --git a/inference.py b/inference.py
index ec82013..5e111cd 100644
--- a/inference.py
+++ b/inference.py
@@ -8,7 +8,7 @@
 import imageio
 import numpy as np
-import safetensors.torch
+from safetensors import safe_open
 import torch
 import torch.nn.functional as F
 from PIL import Image
 
@@ -29,34 +29,33 @@
 MAX_NUM_FRAMES = 257
 
 
-def load_vae(vae_dir):
-    vae_ckpt_path = vae_dir / "vae_diffusion_pytorch_model.safetensors"
-    vae_config_path = vae_dir / "config.json"
-    with open(vae_config_path, "r") as f:
-        vae_config = json.load(f)
+def load_vae(vae_config, ckpt):
     vae = CausalVideoAutoencoder.from_config(vae_config)
-    vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
+    vae_state_dict = {
+        key.replace("vae.", ""): value
+        for key, value in ckpt.items()
+        if key.startswith("vae.")
+    }
     vae.load_state_dict(vae_state_dict)
     if torch.cuda.is_available():
         vae = vae.cuda()
     return vae.to(torch.bfloat16)
 
 
-def load_unet(unet_dir):
-    unet_ckpt_path = unet_dir / "unet_diffusion_pytorch_model.safetensors"
-    unet_config_path = unet_dir / "config.json"
-    transformer_config = Transformer3DModel.load_config(unet_config_path)
+def load_transformer(transformer_config, ckpt):
     transformer = Transformer3DModel.from_config(transformer_config)
-    unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
-    transformer.load_state_dict(unet_state_dict, strict=True)
+    transformer_state_dict = {
+        key.replace("model.diffusion_model.", ""): value
+        for key, value in ckpt.items()
+        if key.startswith("model.diffusion_model.")
+    }
+    transformer.load_state_dict(transformer_state_dict, strict=True)
     if torch.cuda.is_available():
         transformer = transformer.cuda()
     return transformer
 
 
-def load_scheduler(scheduler_dir):
-    scheduler_config_path = scheduler_dir / "scheduler_config.json"
-    scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
+def load_scheduler(scheduler_config):
     return RectifiedFlowScheduler.from_config(scheduler_config)
@@ -168,10 +167,10 @@ def main():
 
     # Directories
     parser.add_argument(
-        "--ckpt_dir",
+        "--ckpt_path",
         type=str,
         required=True,
-        help="Path to the directory containing unet, vae, and scheduler subdirectories",
+        help="Path to a safetensors file that contains all model parts.",
     )
     parser.add_argument(
         "--input_video_path",
@@ -205,6 +204,12 @@ def main():
         default=3,
         help="Guidance scale for the pipeline",
     )
+    parser.add_argument(
+        "--image_cond_noise_scale",
+        type=float,
+        default=0.15,
+        help="Amount of noise to add to the conditioned image",
+    )
     parser.add_argument(
         "--height",
         type=int,
@@ -297,15 +302,22 @@ def main():
         media_items = None
 
-    # Paths for the separate mode directories
-    ckpt_dir = Path(args.ckpt_dir)
-    unet_dir = ckpt_dir / "unet"
-    vae_dir = ckpt_dir / "vae"
-    scheduler_dir = ckpt_dir / "scheduler"
+    # Load the single checkpoint file: all weights plus the bundled configs
+    ckpt_path = Path(args.ckpt_path)
+    ckpt = {}
+    with safe_open(ckpt_path, framework="pt", device="cpu") as f:
+        metadata = f.metadata()
+        for k in f.keys():
+            ckpt[k] = f.get_tensor(k)
+
+    configs = json.loads(metadata["config"])
+    vae_config = configs["vae"]
+    transformer_config = configs["transformer"]
+    scheduler_config = configs["scheduler"]
 
     # Load models
-    vae = load_vae(vae_dir)
-    unet = load_unet(unet_dir)
-    scheduler = load_scheduler(scheduler_dir)
+    vae = load_vae(vae_config, ckpt)
+    transformer = load_transformer(transformer_config, ckpt)
+    scheduler = load_scheduler(scheduler_config)
     patchifier = SymmetricPatchifier(patch_size=1)
     text_encoder = T5EncoderModel.from_pretrained(
         "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder"
@@ -316,12 +328,12 @@ def main():
         "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
     )
 
-    if args.bfloat16 and unet.dtype != torch.bfloat16:
-        unet = unet.to(torch.bfloat16)
+    if args.bfloat16 and transformer.dtype != torch.bfloat16:
+        transformer = transformer.to(torch.bfloat16)
 
     # Use submodels for the pipeline
     submodel_dict = {
-        "transformer": unet,
+        "transformer": transformer,
         "patchifier": patchifier,
         "text_encoder": text_encoder,
         "tokenizer": tokenizer,
@@ -365,6 +377,7 @@ def main():
             if media_items is not None
             else ConditioningMethod.UNCONDITIONAL
         ),
+        image_cond_noise_scale=args.image_cond_noise_scale,
        mixed_precision=not args.bfloat16,
     ).images
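The single-file layout above is easy to exercise on its own: a minimal sketch, assuming the JSON `config` entry in the safetensors header metadata and the `vae.` / `model.diffusion_model.` key prefixes the diff establishes. `load_bundled_checkpoint` and `split_by_prefix` are illustrative helper names, not functions in this repo.

```python
# Sketch only: reading back the single-file checkpoint layout from the diff.
import json

from safetensors import safe_open


def load_bundled_checkpoint(ckpt_path):
    """Return (configs, tensors) from a single-file checkpoint."""
    tensors = {}
    with safe_open(ckpt_path, framework="pt", device="cpu") as f:
        metadata = f.metadata()  # flat str -> str mapping stored in the file header
        for key in f.keys():
            tensors[key] = f.get_tensor(key)
    # One JSON string holds the vae / transformer / scheduler configs.
    configs = json.loads(metadata["config"])
    return configs, tensors


def split_by_prefix(tensors, prefix):
    """Select one sub-model's weights and strip the shared prefix."""
    return {
        key[len(prefix):]: value
        for key, value in tensors.items()
        if key.startswith(prefix)
    }
```

One design note: slicing with `key[len(prefix):]` strips only the leading prefix, whereas `key.replace(prefix, "")` as used in the diff would also rewrite a non-leading occurrence. It is harmless for these key names, but the slice is the stricter idiom.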
diff --git a/ltx_video/pipelines/pipeline_ltx_video.py b/ltx_video/pipelines/pipeline_ltx_video.py
index 119f700..f0e7832 100644
--- a/ltx_video/pipelines/pipeline_ltx_video.py
+++ b/ltx_video/pipelines/pipeline_ltx_video.py
@@ -655,6 +655,26 @@ def _clean_caption(self, caption):
 
         return caption.strip()
 
+    def image_cond_noise_update(
+        self,
+        t,
+        init_latents,
+        latents,
+        noise_scale,
+        conditioning_mask,
+        generator,
+    ):
+        noise = randn_tensor(
+            latents.shape,
+            generator=generator,
+            device=latents.device,
+            dtype=latents.dtype,
+        )
+        latents = (init_latents + noise_scale * noise * (t**2)) * conditioning_mask[
+            ..., None
+        ] + latents * (1 - conditioning_mask[..., None])
+        return latents
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
     def prepare_latents(
         self,
@@ -897,6 +917,7 @@ def __call__(
         self.video_scale_factor = self.video_scale_factor if is_video else 1
         conditioning_method = kwargs.get("conditioning_method", None)
         vae_per_channel_normalize = kwargs.get("vae_per_channel_normalize", False)
+        image_cond_noise_scale = kwargs.get("image_cond_noise_scale", 0.0)
         init_latents, conditioning_mask = self.prepare_conditioning(
             media_items,
             num_frames,
@@ -924,6 +945,7 @@ def __call__(
             latents=init_latents,
             latents_mask=conditioning_mask,
         )
+        orig_conditioning_mask = conditioning_mask
         if conditioning_mask is not None and is_video:
             assert num_images_per_prompt == 1
             conditioning_mask = (
@@ -954,6 +976,15 @@ def __call__(
 
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
+                if conditioning_method == ConditioningMethod.FIRST_FRAME:
+                    latents = self.image_cond_noise_update(
+                        t,
+                        init_latents,
+                        latents,
+                        image_cond_noise_scale,
+                        orig_conditioning_mask,
+                        generator,
+                    )
                 latent_model_input = (
                     torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                 )
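The pipeline change above re-perturbs only the conditioned latent positions at every denoising step. A standalone sketch of the same update, using `torch.randn` in place of diffusers' `randn_tensor`; the helper name is illustrative:

```python
# Illustrative sketch of the per-step conditioning-noise refresh.
import torch


def refresh_conditioned_latents(
    t, init_latents, latents, noise_scale, conditioning_mask, generator=None
):
    # Fresh noise every step, matching the latents' shape, device, and dtype.
    noise = torch.randn(
        latents.shape,
        generator=generator,
        device=latents.device,
        dtype=latents.dtype,
    )
    mask = conditioning_mask[..., None]  # broadcast over the feature axis
    # Conditioned tokens: clean init latents plus noise scaled by noise_scale * t**2,
    # so the perturbation shrinks as the timestep t decreases.
    # Unconditioned tokens: left exactly as the sampler produced them.
    return (init_latents + noise_scale * noise * (t**2)) * mask + latents * (1 - mask)
```

Note the two defaults: with the pipeline-side default of `image_cond_noise_scale=0.0` the update degenerates to re-pinning the conditioned tokens to their clean latents each step, while the CLI default of `0.15` layers on a perturbation that shrinks with `t**2` as sampling converges.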
diff --git a/pyproject.toml b/pyproject.toml
index 3f17249..30da2d1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,7 +13,6 @@ requires-python = ">=3.10"
 readme = "README.md"
 classifiers = [
     "Programming Language :: Python :: 3",
-    "License :: OSI Approved :: MIT License",
     "Operating System :: OS Independent"
 ]
 dependencies = [
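A natural companion to the new `--ckpt_path` flow is producing such a file in the first place. A hypothetical packaging sketch, not shipped in this PR, under the same layout assumptions (`bundle_checkpoint` is an illustrative name; `vae_sd` / `transformer_sd` are plain state dicts and `configs` is the dict read back by `inference.py`):

```python
# Hypothetical companion script: bundling separate parts into the
# single-file layout that --ckpt_path expects.
import json

import safetensors.torch


def bundle_checkpoint(vae_sd, transformer_sd, configs, out_path):
    # Namespace each sub-model's weights with the prefixes the loader strips.
    tensors = {f"vae.{key}": value for key, value in vae_sd.items()}
    tensors.update(
        {f"model.diffusion_model.{key}": value for key, value in transformer_sd.items()}
    )
    # safetensors metadata must be a flat str -> str mapping, so the three
    # configs are serialized to a single JSON string under the "config" key.
    metadata = {"config": json.dumps(configs)}
    safetensors.torch.save_file(tensors, out_path, metadata=metadata)
```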