Changes for sample T2V #227

Open · wants to merge 1 commit into base: main
19 changes: 13 additions & 6 deletions opensora/sample/pipeline_videogen.py
@@ -19,11 +19,8 @@
from typing import Callable, List, Optional, Tuple, Union

import torch
import einops
from einops import rearrange
from transformers import T5EncoderModel, T5Tokenizer

from diffusers.image_processor import VaeImageProcessor
from diffusers.models import AutoencoderKL, Transformer2DModel
from diffusers.schedulers import DPMSolverMultistepScheduler
from diffusers.utils import (
@@ -501,8 +498,18 @@ def _clean_caption(self, caption):
return caption.strip()

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator,
latents=None):
def prepare_latents(
self,
batch_size: int,
num_channels_latents: int,
video_length: int,
height: int,
width: int,
dtype: torch.dtype,
device: Union[str, torch.device],
generator: Optional[torch.Generator],
latents: Optional[torch.FloatTensor] = None
):
shape = (
batch_size, num_channels_latents, video_length, self.vae.latent_size[0], self.vae.latent_size[1])
if isinstance(generator, list) and len(generator) != batch_size:
@@ -750,7 +757,7 @@ def __call__(

return VideoPipelineOutput(video=video)

def decode_latents(self, latents):
def decode_latents(self, latents: torch.FloatTensor):
video = self.vae.decode(latents)
# video = self.vae.decode(latents / 0.18215)
# video = rearrange(video, 'b c t h w -> b t c h w').contiguous()
231 changes: 141 additions & 90 deletions opensora/sample/sample_t2v.py
@@ -1,101 +1,170 @@
import math
import os
import torch
import argparse
import torchvision
import os, sys
from typing import List, Union

import imageio
import torch
from torchvision.utils import save_image
from diffusers.schedulers import (DDIMScheduler, DDPMScheduler, PNDMScheduler,
EulerDiscreteScheduler, DPMSolverMultistepScheduler,
HeunDiscreteScheduler, EulerAncestralDiscreteScheduler,
DEISMultistepScheduler, KDPM2AncestralDiscreteScheduler)
from diffusers.schedulers.scheduling_dpmsolver_singlestep import DPMSolverSinglestepScheduler
from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
from omegaconf import OmegaConf
from torchvision.utils import save_image
from transformers import T5EncoderModel, T5Tokenizer, AutoTokenizer

import os, sys
from transformers import T5EncoderModel, T5Tokenizer

from opensora.models.ae import ae_stride_config, getae, getae_wrapper
from opensora.models.ae.videobase import CausalVQVAEModelWrapper, CausalVAEModelWrapper
from opensora.models.ae import ae_stride_config, getae_wrapper
from opensora.models.diffusion.latte.modeling_latte import LatteT2V
from opensora.models.text_encoder import get_text_enc
from opensora.utils.utils import save_video_grid

sys.path.append(os.path.split(sys.path[0])[0])
from pipeline_videogen import VideoGenPipeline

import imageio


def main(args):
# torch.manual_seed(args.seed)
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"

vae = getae_wrapper(args.ae)(args.model_path, subfolder="vae", cache_dir='cache_dir').to(device, dtype=torch.float16)
def get_models(args: argparse.Namespace, device: str):
vae = getae_wrapper(args.ae)(args.model_path, subfolder="vae").to(device, dtype=torch.float16)
if args.enable_tiling:
vae.vae.enable_tiling()
vae.vae.tile_overlap_factor = args.tile_overlap_factor

# Load model:
transformer_model = LatteT2V.from_pretrained(args.model_path, subfolder=args.version, cache_dir="cache_dir", torch_dtype=torch.float16).to(device)
transformer_model = LatteT2V.from_pretrained(args.model_path, subfolder=args.version, torch_dtype=torch.float16).to(device)
transformer_model.force_images = args.force_images
tokenizer = T5Tokenizer.from_pretrained(args.text_encoder_name, cache_dir="cache_dir")
text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_name, cache_dir="cache_dir", torch_dtype=torch.float16).to(device)

video_length, image_size = transformer_model.config.video_length, int(args.version.split('x')[1])
latent_size = (image_size // ae_stride_config[args.ae][1], image_size // ae_stride_config[args.ae][2])
vae.latent_size = latent_size
if args.force_images:
video_length = 1
ext = 'jpg'
else:
ext = 'mp4'
tokenizer = T5Tokenizer.from_pretrained(args.text_encoder_name)
text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_name, torch_dtype=torch.float16).to(device)

# set eval mode
transformer_model.eval()
vae.eval()
text_encoder.eval()

return transformer_model, vae, text_encoder, tokenizer


def get_scheduler(sample_method: str):
schedulers = {
'DDIM': DDIMScheduler(),
'EulerDiscrete': EulerDiscreteScheduler(),
'DDPM': DDPMScheduler(),
'DPMSolverMultistep': DPMSolverMultistepScheduler(),
'DPMSolverSinglestep': DPMSolverSinglestepScheduler(),
'PNDM': PNDMScheduler(),
'HeunDiscrete': HeunDiscreteScheduler(),
'EulerAncestralDiscrete': EulerAncestralDiscreteScheduler(),
'DEISMultistep': DEISMultistepScheduler(),
'KDPM2AncestralDiscrete': KDPM2AncestralDiscreteScheduler()
}
return schedulers[sample_method]


def get_text_prompt(text_prompt: Union[List[str], str]):
if not isinstance(text_prompt, list):
text_prompt = [text_prompt]

if len(text_prompt) == 1 and text_prompt[0].endswith('txt'):
text_prompt = open(text_prompt[0], 'r').readlines()
text_prompt = [i.strip() for i in text_prompt]
return text_prompt


def save_video(videos: torch.FloatTensor, prompt: str, args: argparse.Namespace):
"""
Save a single video (output of pipeline).
"""
# Save results
try:
if args.force_images:
videos = videos[:, 0].permute(0, 3, 1, 2) # b t h w c -> b c h w
save_image(
videos / 255.0,
os.path.join(
args.save_img_path,
prompt.replace(' ', '_')[:100] + f'{args.sample_method}_gs{args.guidance_scale}_s{args.num_sampling_steps}.{args.ext}'),
nrow=1, normalize=True, value_range=(0, 1)) # t c h w

else:
imageio.mimwrite(
os.path.join(
args.save_img_path,
prompt.replace(' ', '_')[:100] + f'{args.sample_method}_gs{args.guidance_scale}_s{args.num_sampling_steps}.{args.ext}'
),
videos[0],
fps=args.fps,
quality=9) # highest quality is 10, lowest is 0
except Exception as e:
print('Error when saving {}: {}'.format(prompt, e))

return videos


def save_grid(video_grids: List[torch.FloatTensor], args: argparse.Namespace):
video_grids = torch.cat(video_grids, dim=0)

# Save results
# torchvision.io.write_video(args.save_img_path + '_%04d' % args.run_time + '-.mp4', video_grids, fps=6)
if args.force_images:
save_image(
video_grids / 255.0,
os.path.join(
args.save_img_path,
f'{args.sample_method}_gs{args.guidance_scale}_s{args.num_sampling_steps}.{args.ext}'),
nrow=math.ceil(math.sqrt(len(video_grids))),
normalize=True, value_range=(0, 1))
else:
video_grids = save_video_grid(video_grids)
imageio.mimwrite(
os.path.join(args.save_img_path, f'{args.sample_method}_gs{args.guidance_scale}_s{args.num_sampling_steps}.{args.ext}'),
video_grids,
fps=args.fps,
quality=9)

if args.sample_method == 'DDIM': #########
scheduler = DDIMScheduler()
elif args.sample_method == 'EulerDiscrete':
scheduler = EulerDiscreteScheduler()
elif args.sample_method == 'DDPM': #############
scheduler = DDPMScheduler()
elif args.sample_method == 'DPMSolverMultistep':
scheduler = DPMSolverMultistepScheduler()
elif args.sample_method == 'DPMSolverSinglestep':
scheduler = DPMSolverSinglestepScheduler()
elif args.sample_method == 'PNDM':
scheduler = PNDMScheduler()
elif args.sample_method == 'HeunDiscrete': ########
scheduler = HeunDiscreteScheduler()
elif args.sample_method == 'EulerAncestralDiscrete':
scheduler = EulerAncestralDiscreteScheduler()
elif args.sample_method == 'DEISMultistep':
scheduler = DEISMultistepScheduler()
elif args.sample_method == 'KDPM2AncestralDiscrete': #########
scheduler = KDPM2AncestralDiscreteScheduler()
print('save path {}'.format(args.save_img_path))

# save_videos_grid(video, f"./{prompt}.gif")


def main(args):
# torch.manual_seed(args.seed)
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"
print('videogen_pipeline', device)


# Prepare models and pipeline
transformer_model, vae, text_encoder, tokenizer = get_models(args, device)
scheduler = get_scheduler(args.sample_method)
videogen_pipeline = VideoGenPipeline(vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
scheduler=scheduler,
transformer=transformer_model).to(device=device)
transformer=transformer_model,
)
# Some pipeline configs
# videogen_pipeline.enable_sequential_cpu_offload()
# videogen_pipeline.enable_xformers_memory_efficient_attention()


# Prepare
video_length, image_size = transformer_model.config.video_length, int(args.version.split('x')[1])
latent_size = (image_size // ae_stride_config[args.ae][1], image_size // ae_stride_config[args.ae][2])
vae.latent_size = latent_size
if args.force_images:
video_length = 1
args.ext = 'jpg'
else:
args.ext = 'mp4'

if not os.path.exists(args.save_img_path):
os.makedirs(args.save_img_path)


# Get text prompts
text_prompt = get_text_prompt(args.text_prompt)


# Video generation
video_grids = []
if not isinstance(args.text_prompt, list):
args.text_prompt = [args.text_prompt]
if len(args.text_prompt) == 1 and args.text_prompt[0].endswith('txt'):
text_prompt = open(args.text_prompt[0], 'r').readlines()
args.text_prompt = [i.strip() for i in text_prompt]
for prompt in args.text_prompt:
for prompt in text_prompt:
print('Processing the ({}) prompt'.format(prompt))
videos = videogen_pipeline(prompt,
video_length=video_length,
@@ -107,37 +107,18 @@ def main(args):
num_images_per_prompt=1,
mask_feature=True,
).video
try:
if args.force_images:
videos = videos[:, 0].permute(0, 3, 1, 2) # b t h w c -> b c h w
save_image(videos / 255.0, os.path.join(args.save_img_path,
prompt.replace(' ', '_')[:100] + f'{args.sample_method}_gs{args.guidance_scale}_s{args.num_sampling_steps}.{ext}'),
nrow=1, normalize=True, value_range=(0, 1)) # t c h w

else:
imageio.mimwrite(
os.path.join(
args.save_img_path,
prompt.replace(' ', '_')[:100] + f'{args.sample_method}_gs{args.guidance_scale}_s{args.num_sampling_steps}.{ext}'
), videos[0],
fps=args.fps, quality=9) # highest quality is 10, lowest is 0
except:
print('Error when saving {}'.format(prompt))
video_grids.append(videos)
video_grids = torch.cat(video_grids, dim=0)


# torchvision.io.write_video(args.save_img_path + '_%04d' % args.run_time + '-.mp4', video_grids, fps=6)
if args.force_images:
save_image(video_grids / 255.0, os.path.join(args.save_img_path, f'{args.sample_method}_gs{args.guidance_scale}_s{args.num_sampling_steps}.{ext}'),
nrow=math.ceil(math.sqrt(len(video_grids))), normalize=True, value_range=(0, 1))
else:
video_grids = save_video_grid(video_grids)
imageio.mimwrite(os.path.join(args.save_img_path, f'{args.sample_method}_gs{args.guidance_scale}_s{args.num_sampling_steps}.{ext}'), video_grids, fps=args.fps, quality=9)

print('save path {}'.format(args.save_img_path))

# save_videos_grid(video, f"./{prompt}.gif")

videos = save_video(videos, prompt, args)

# Save result
if args.save_grid:
video_grids.append(videos)


# Save results
if args.save_grid:
save_grid(video_grids, args)



if __name__ == "__main__":
@@ -156,6 +206,7 @@ def main(args):
parser.add_argument('--force_images', action='store_true')
parser.add_argument('--tile_overlap_factor', type=float, default=0.25)
parser.add_argument('--enable_tiling', action='store_true')
parser.add_argument('--save_grid', action='store_true', help='Save all prompts in a grid')
args = parser.parse_args()

main(args)
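
For reference, a standalone sketch of the dictionary-based scheduler selection that this commit introduces in get_scheduler(), replacing the old if/elif chain. It only assumes diffusers is installed; the keys shown are a subset of the mapping in sample_t2v.py, and unlike the PR, which constructs every scheduler instance up front, this variant stores the classes and instantiates only the one requested.

# Sketch only: mirrors the get_scheduler() mapping added in this commit.
from diffusers.schedulers import (DDIMScheduler, DDPMScheduler, PNDMScheduler,
                                  EulerDiscreteScheduler, DPMSolverMultistepScheduler)

def pick_scheduler(sample_method: str):
    # Keys follow the --sample_method values handled by sample_t2v.py
    # (a subset is shown; the PR maps ten scheduler names).
    scheduler_classes = {
        'DDIM': DDIMScheduler,
        'DDPM': DDPMScheduler,
        'PNDM': PNDMScheduler,
        'EulerDiscrete': EulerDiscreteScheduler,
        'DPMSolverMultistep': DPMSolverMultistepScheduler,
    }
    # Instantiate lazily so only the requested scheduler is constructed.
    return scheduler_classes[sample_method]()

print(type(pick_scheduler('PNDM')).__name__)  # -> PNDMScheduler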