-
Notifications
You must be signed in to change notification settings - Fork 24
/
Copy pathpredict.py
62 lines (53 loc) · 1.94 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
from cog import BasePredictor, Input, Path
import imageio
# Local directory where from_pretrained() caches/loads the model weights
# (paired with local_files_only=True below, so weights must be pre-downloaded here).
MODEL_CACHE = "model-cache"
class Predictor(BasePredictor):
    """Cog predictor wrapping the ModelScope text-to-video pipeline
    (damo-vilab/text-to-video-ms-1.7b) to produce short MP4 clips."""

    def setup(self):
        """Load the model into memory to make running multiple predictions efficient"""
        self.pipe = DiffusionPipeline.from_pretrained(
            "damo-vilab/text-to-video-ms-1.7b",
            torch_dtype=torch.float16,
            variant="fp16",
            cache_dir=MODEL_CACHE,
            local_files_only=True,  # weights must already exist in MODEL_CACHE
        )
        self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            self.pipe.scheduler.config
        )
        # Fix: do NOT call .to("cuda") before enable_model_cpu_offload().
        # The diffusers docs require the pipeline to stay on CPU so the
        # accelerate hooks can move each submodule to the GPU only while it
        # runs; pre-moving the whole pipeline defeats the memory savings and
        # raises an error on recent diffusers versions.
        self.pipe.enable_model_cpu_offload()
        # Decode the VAE one frame at a time to lower peak GPU memory.
        self.pipe.enable_vae_slicing()

    def predict(
        self,
        prompt: str = Input(
            description="Input prompt", default="An astronaut riding a horse"
        ),
        num_frames: int = Input(
            description="Number of frames for the output video", default=16
        ),
        num_inference_steps: int = Input(
            description="Number of denoising steps", ge=1, le=500, default=50
        ),
        fps: int = Input(description="fps for the output video", default=8),
        seed: int = Input(
            description="Random seed. Leave blank to randomize the seed", default=None
        ),
    ) -> Path:
        """Run one text-to-video generation and return the path to an MP4.

        Raises whatever the underlying pipeline / ffmpeg writer raises; the
        writer is always closed so a partial file is still finalized.
        """
        if seed is None:
            # 2 random bytes -> seed in [0, 65535]; kept for reproducible logs.
            seed = int.from_bytes(os.urandom(2), "big")
        print(f"Using seed: {seed}")
        generator = torch.Generator("cuda").manual_seed(seed)
        frames = self.pipe(
            prompt,
            num_inference_steps=num_inference_steps,
            num_frames=num_frames,
            generator=generator,
        ).frames
        # NOTE(review): on newer diffusers releases, .frames is a list of
        # per-prompt frame lists (frames[0] would be the video) — confirm
        # against the pinned diffusers version before upgrading.
        out = "/tmp/out.mp4"
        # Context manager guarantees the ffmpeg writer is closed (and the MP4
        # container finalized) even if append_data raises mid-stream.
        with imageio.get_writer(out, format="FFMPEG", fps=fps) as writer:
            for frame in frames:
                writer.append_data(frame)
        return Path(out)