Update: Version 0.1.2 #69

Merged 8 commits on Dec 19, 2024
75 changes: 70 additions & 5 deletions README.md
@@ -7,7 +7,7 @@ This is the official repository for LTX-Video.
[Website](https://www.lightricks.com/ltxv) |
[Model](https://huggingface.co/Lightricks/LTX-Video) |
[Demo](https://fal.ai/models/fal-ai/ltx-video) |
[Paper (Soon)](https://github.com/Lightricks/LTX-Video)

</div>

@@ -20,7 +20,11 @@ This is the official repository for LTX-Video.
- [Installation](#installation)
- [Inference](#inference)
- [ComfyUI Integration](#comfyui-integration)
- [Diffusers Integration](#diffusers-integration)
- [Model User Guide](#model-user-guide)
- [Community Contribution](#community-contribution)
- [Training](#training)
- [Join Us!](#join-us)
- [Acknowledgement](#acknowledgement)

# Introduction
@@ -60,13 +64,13 @@ source env/bin/activate
python -m pip install -e .\[inference-script\]
```

Then, download the model from [Hugging Face](https://huggingface.co/Lightricks/LTX-Video)

```python
from huggingface_hub import hf_hub_download

model_path = 'PATH'  # The local directory to save the downloaded checkpoint
hf_hub_download(repo_id="Lightricks/LTX-Video", filename="ltx-video-2b-v0.9.1.safetensors", local_dir=model_path, local_dir_use_symlinks=False, repo_type='model')
```
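
Note that `hf_hub_download` returns the local path of the downloaded `.safetensors` file; that path (or `model_path` plus the filename) is what you pass to the inference script's checkpoint argument (`--ckpt_path`).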

### Inference
@@ -113,7 +117,68 @@ When writing prompts, focus on detailed, chronological descriptions of actions a
* Guidance Scale: 3-3.5 are the recommended values
* Inference Steps: More steps (40+) for quality, fewer steps (20-30) for speed
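
For example, here is a minimal text-to-video sketch that applies these recommendations through the Diffusers integration. It is an illustration rather than the repository's reference command: it assumes a diffusers release that ships `LTXPipeline` (0.32 or later), and the prompt, resolution, frame count, and output filename are arbitrary choices.

```python
import torch
from diffusers import LTXPipeline
from diffusers.utils import export_to_video

# Load the diffusers-format weights published on the Hugging Face Hub.
pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Guidance scale 3.0 and 40 steps follow the guidance above (favoring quality over speed).
video = pipe(
    prompt="A woman walks along a rain-soaked city street at night, neon signs reflecting in the puddles",
    negative_prompt="worst quality, inconsistent motion, blurry, jittery, distorted",
    width=704,
    height=480,
    num_frames=121,
    guidance_scale=3.0,
    num_inference_steps=40,
).frames[0]

export_to_video(video, "output.mp4", fps=24)
```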

## Community Contribution

### ComfyUI-LTXTricks 🛠️

A community project providing additional nodes for enhanced control over the LTX Video model. It includes implementations of advanced techniques like RF-Inversion, RF-Edit, FlowEdit, and more. These nodes enable workflows such as Image and Video to Video (I+V2V), enhanced sampling via Spatiotemporal Skip Guidance (STG), and interpolation with precise frame settings.

- **Repository:** [ComfyUI-LTXTricks](https://github.com/logtd/ComfyUI-LTXTricks)
- **Features:**
- 🔄 **RF-Inversion:** Implements [RF-Inversion](https://rf-inversion.github.io/) with an [example workflow here](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_inversion.json).
- ✂️ **RF-Edit:** Implements [RF-Solver-Edit](https://github.com/wangjiangshan0725/RF-Solver-Edit) with an [example workflow here](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_rf_edit.json).
- 🌊 **FlowEdit:** Implements [FlowEdit](https://github.com/fallenshock/FlowEdit) with an [example workflow here](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_flow_edit.json).
- 🎥 **I+V2V:** Enables Video to Video with a reference image. [Example workflow](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_iv2v.json).
- ✨ **Enhance:** Partial implementation of [STGuidance](https://junhahyung.github.io/STGuidance/). [Example workflow](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltxv_stg.json).
- 🖼️ **Interpolation and Frame Setting:** Nodes for precise control of latents per frame. [Example workflow](https://github.com/logtd/ComfyUI-LTXTricks/blob/main/example_workflows/example_ltx_interpolation.json).


### LTX-VideoQ8 🎱

**LTX-VideoQ8** is an 8-bit optimized version of [LTX-Video](https://github.com/Lightricks/LTX-Video), designed for faster performance on NVIDIA ADA GPUs.

- **Repository:** [LTX-VideoQ8](https://github.com/KONAKONA666/LTX-Video)
- **Features:**
- 🚀 Up to 3X speed-up with no accuracy loss
- 🎥 Generate 720x480x121 videos in under a minute on RTX 4060 (8GB VRAM)
- 🛠️ Fine-tune 2B transformer models with precalculated latents
- **Community Discussion:** [Reddit Thread](https://www.reddit.com/r/StableDiffusion/comments/1h79ks2/fast_ltx_video_on_rtx_4060_and_other_ada_gpus/)

### Your Contribution

...is welcome! If you have a project or tool that integrates with LTX-Video,
please let us know by opening an issue or pull request.

# Training

## Diffusers

Diffusers implemented [LoRA support](https://github.com/huggingface/diffusers/pull/10228),
with a training script for fine-tuning.
More information and the training script are available in
[finetrainers](https://github.com/a-r-r-o-w/finetrainers?tab=readme-ov-file#training).
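
As a rough sketch of how such a fine-tuned LoRA could then be used at inference time (assuming a diffusers version whose `LTXPipeline` supports LoRA loading; the checkpoint path and prompt are hypothetical placeholders):

```python
import torch
from diffusers import LTXPipeline
from diffusers.utils import export_to_video

pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Hypothetical path to LoRA weights produced by a finetrainers run.
pipe.load_lora_weights("path/to/ltxv-lora")

video = pipe(prompt="A clay-animation fox trots through a snowy forest", num_frames=121).frames[0]
export_to_video(video, "lora_output.mp4", fps=24)
```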

## Diffusion-Pipe

An experimental training framework with pipeline parallelism, enabling fine-tuning of large models like **LTX-Video** across multiple GPUs.

- **Repository:** [Diffusion-Pipe](https://github.com/tdrussell/diffusion-pipe)
- **Features:**
- 🛠️ Full fine-tune support for LTX-Video using LoRA
- 📊 Useful metrics logged to Tensorboard
- 🔄 Training state checkpointing and resumption
- ⚡ Efficient pre-caching of latents and text embeddings for multi-GPU setups


# Join Us 🚀

Want to work on cutting-edge AI research and make a real impact on millions of users worldwide?

At **Lightricks**, an AI-first company, we’re revolutionizing how visual content is created.

If you are passionate about AI, computer vision, and video generation, we would love to hear from you!

Please visit our [careers page](https://careers.lightricks.com/careers?query=&office=all&department=R%26D) for more information.

# Acknowledgement

139 changes: 85 additions & 54 deletions inference.py
@@ -1,5 +1,4 @@
import argparse
import json
import os
import random
from datetime import datetime
@@ -8,7 +7,6 @@

import imageio
import numpy as np
from safetensors import safe_open
import torch
import torch.nn.functional as F
from PIL import Image
@@ -22,41 +20,18 @@
from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline
from ltx_video.schedulers.rf import RectifiedFlowScheduler
from ltx_video.utils.conditioning_method import ConditioningMethod

from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy

MAX_HEIGHT = 720
MAX_WIDTH = 1280
MAX_NUM_FRAMES = 257


def load_vae(vae_config, ckpt):
vae = CausalVideoAutoencoder.from_config(vae_config)
vae_state_dict = {
key.replace("vae.", ""): value
for key, value in ckpt.items()
if key.startswith("vae.")
}
vae.load_state_dict(vae_state_dict)
if torch.cuda.is_available():
vae = vae.cuda()
return vae.to(torch.bfloat16)


def load_transformer(transformer_config, ckpt):
    transformer = Transformer3DModel.from_config(transformer_config)
    transformer_state_dict = {
        key.replace("model.diffusion_model.", ""): value
        for key, value in ckpt.items()
        if key.startswith("model.diffusion_model.")
    }
    transformer.load_state_dict(transformer_state_dict, strict=True)
    if torch.cuda.is_available():
        transformer = transformer.cuda()
    return transformer


def load_scheduler(scheduler_config):
    return RectifiedFlowScheduler.from_config(scheduler_config)


def get_total_gpu_memory():
    # Total memory of the default CUDA device in GiB; None when no GPU is available.
    if torch.cuda.is_available():
        total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        return total_memory
    return None


def load_image_to_tensor_with_resize_and_crop(
@@ -204,6 +179,30 @@ def main():
default=3,
help="Guidance scale for the pipeline",
)
parser.add_argument(
"--stg_scale",
type=float,
default=1,
help="Spatiotemporal guidance scale for the pipeline. 0 to disable STG.",
)
parser.add_argument(
"--stg_rescale",
type=float,
default=0.7,
help="Spatiotemporal guidance rescaling scale for the pipeline. 1 to disable rescale.",
)
parser.add_argument(
"--stg_mode",
type=str,
default="stg_a",
help="Spatiotemporal guidance mode for the pipeline. Can be either stg_a or stg_r.",
)
parser.add_argument(
"--stg_skip_layers",
type=str,
default="19",
help="Attention layers to skip for spatiotemporal guidance. Comma separated list of integers.",
)
parser.add_argument(
"--image_cond_noise_scale",
type=float,
@@ -233,9 +232,24 @@
)

parser.add_argument(
    "--bfloat16",
    action="store_true",
    help="Denoise in bfloat16",
)
parser.add_argument(
    "--precision",
    choices=["bfloat16", "mixed_precision"],
    default="bfloat16",
    help="Sets the precision for the transformer and tokenizer. Default is bfloat16. If 'mixed_precision' is enabled, it moves to mixed-precision.",
)

# VAE noise augmentation
parser.add_argument(
"--decode_timestep",
type=float,
default=0.05,
help="Timestep for decoding noise",
)
parser.add_argument(
"--decode_noise_scale",
type=float,
default=0.025,
help="Noise level for decoding noise",
)

# Prompts
@@ -251,6 +265,12 @@
help="Negative prompt for undesired features",
)

parser.add_argument(
"--offload_to_cpu",
action="store_true",
help="Offloading unnecessary computations to CPU.",
)

logger = logging.get_logger(__name__)

args = parser.parse_args()
@@ -259,6 +279,8 @@

seed_everething(args.seed)

# Offload work to the CPU only when --offload_to_cpu is set and the GPU has less than 30 GB of total memory.
offload_to_cpu = False if not args.offload_to_cpu else get_total_gpu_memory() < 30

output_dir = (
Path(args.output_path)
if args.output_path
@@ -301,35 +323,36 @@
else:
media_items = None

# Paths for the separate mode directories
ckpt_path = Path(args.ckpt_path)
ckpt = {}
with safe_open(ckpt_path, framework="pt", device="cpu") as f:
metadata = f.metadata()
for k in f.keys():
ckpt[k] = f.get_tensor(k)

configs = json.loads(metadata["config"])
vae_config = configs["vae"]
transformer_config = configs["transformer"]
scheduler_config = configs["scheduler"]

# Load models
vae = load_vae(vae_config, ckpt)
transformer = load_transformer(transformer_config, ckpt)
scheduler = load_scheduler(scheduler_config)
patchifier = SymmetricPatchifier(patch_size=1)
vae = CausalVideoAutoencoder.from_pretrained(ckpt_path)
transformer = Transformer3DModel.from_pretrained(ckpt_path)
scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path)

text_encoder = T5EncoderModel.from_pretrained(
"PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder"
)
if torch.cuda.is_available():
text_encoder = text_encoder.to("cuda")
patchifier = SymmetricPatchifier(patch_size=1)
tokenizer = T5Tokenizer.from_pretrained(
"PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer"
)

if torch.cuda.is_available():
    transformer = transformer.cuda()
    vae = vae.cuda()
    text_encoder = text_encoder.cuda()

vae = vae.to(torch.bfloat16)
if args.precision == "bfloat16" and transformer.dtype != torch.bfloat16:
    transformer = transformer.to(torch.bfloat16)
    text_encoder = text_encoder.to(torch.bfloat16)
# Set spatiotemporal guidance
skip_block_list = [int(x.strip()) for x in args.stg_skip_layers.split(",")]
skip_layer_strategy = (
SkipLayerStrategy.Attention
if args.stg_mode.lower() == "stg_a"
else SkipLayerStrategy.Residual
)

# Use submodels for the pipeline
submodel_dict = {
@@ -362,6 +385,11 @@
num_inference_steps=args.num_inference_steps,
num_images_per_prompt=args.num_images_per_prompt,
guidance_scale=args.guidance_scale,
skip_layer_strategy=skip_layer_strategy,
skip_block_list=skip_block_list,
stg_scale=args.stg_scale,
do_rescaling=args.stg_rescale != 1,
rescaling_scale=args.stg_rescale,
generator=generator,
output_type="pt",
callback_on_step_end=None,
@@ -378,7 +406,10 @@
else ConditioningMethod.UNCONDITIONAL
),
image_cond_noise_scale=args.image_cond_noise_scale,
decode_timestep=args.decode_timestep,
decode_noise_scale=args.decode_noise_scale,
mixed_precision=(args.precision == "mixed_precision"),
offload_to_cpu=offload_to_cpu,
).images

# Crop the padded images to the desired resolution and number of frames