Skip to content

Commit

Permalink
implement mlp broadcast (#150)
Browse files Browse the repository at this point in the history
* init commit (#3)

* update diffusion

* polish

* update structure

* embed

* update ref

* update ref

* update opensora init

* remove

* update

* update name

* fix v1

* update ignore

* update stdit2

* update v2 infer

* support v2 infer

* update script

* polish code

* update script

* update script and rename

* fix cross

* update code for temporal and cross

* update base script

* remove layernorm for infer

* update

* add readme (#4)

* update default setting (#5)

* Spatial attention (#6)

* update prompt

* update spatial

* update code control

* Merge (#7)

* update diffusion

* polish

* update structure

* embed

* update ref

* update ref

* update opensora init

* remove

* update

* update name

* fix v1

* update ignore

* update stdit2

* update v2 infer

* support v2 infer

* update script

* polish code

* update script

* update script and rename

* fix cross

* fix some bugs

* optimize attn

* add latte

* update

* update opensora sample

* remove useless arg

* Update save_dir path in opensora sample configs

* update test

* seed

* update opensora plan

* update opensora plan

* update yaml

* remove useless

* update opensora plan

* update doc

* update script

* update data

* update code

* update usage

* update

* add attn skip for latte and opensoraplan (#8)

* add block

* fix

* update latte

* update

* add latte args (#9)

* fix latte argparse (#10)

* add kw

* fix arg

* add eval

* update ignore

* Opensorav1.2 (#11)

* add opensorav1.2

* update final

* update mgr

* support opensora skip (#12)

* update opensora yaml

* add shard

* update opensora skip

* update config for other models (#13)

* finish opensora-plan and latte

* Sp (#14)

* update latte shard

* rename arg

* update license

* update lincense file

* update dynamic sp for opensoraplan and latte

* Update STDiT3Block and STDiT3 classes in stdit3.py

* finish opensora

* code refine

* code refine

* final code clean

* update skip_diffusion_timestep function to handle more than 1 step skip

* update skip config (#16)

* update range for all

* update opensora plan

* update code

* polish code

* polish print

* update default value

* update verbose

* finish latte eval generate function

* update gitignore

* finish eval
TODO:
BUG: check whether center crop is correct
TODO: original video size is smaller than the generated videos — how should we resize?

* update eval save

* update mlp skip

* update

* parse mlp skip

* update

* update

* update

* update test mlp mse

* update

* update

* update

* update skip s t

* update

* update

* update

* update

* update latte

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update finish skip function

* delete and change file name to clean code

* update

* update

* clean code finish

* delete mse code and clean code finish

* delete and clean code

* clean eval code

* delete asset

* update

* polish code

* polish

* update rflow

---------

Co-authored-by: Xuanlei Zhao <[email protected]>
Co-authored-by: Jinxiaolong1129 <[email protected]>
  • Loading branch information
3 people committed Aug 23, 2024
1 parent 5b50248 commit 994a348
Show file tree
Hide file tree
Showing 14 changed files with 637 additions and 130 deletions.
20 changes: 20 additions & 0 deletions configs/latte/sample_pab.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,23 @@ cross_threshold: [100, 850]
cross_gap: 7
# diffusion_skip: True
# diffusion_skip_timestep: [1,1,1,0,0,0,0,0,0,0]


# mlp skip
mlp_skip: True

mlp_spatial_skip_config: {
720: {'block': [0, 1, 2, 3, 4], 'skip_count': 2},
640: {'block': [0, 1, 2, 3, 4], 'skip_count': 2},
560: {'block': [0, 1, 2, 3, 4], 'skip_count': 2},
480: {'block': [0, 1, 2, 3, 4], 'skip_count': 2},
400: {'block': [0, 1, 2, 3, 4], 'skip_count': 2},
}

mlp_temporal_skip_config: {
720: {'block': [0, 1, 2, 3, 4], 'skip_count': 2},
640: {'block': [0, 1, 2, 3, 4], 'skip_count': 2},
560: {'block': [0, 1, 2, 3, 4], 'skip_count': 2},
480: {'block': [0, 1, 2, 3, 4], 'skip_count': 2},
400: {'block': [0, 1, 2, 3, 4], 'skip_count': 2},
}
21 changes: 21 additions & 0 deletions configs/opensora/sample_pab.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,24 @@ cross_threshold: [540, 930]
cross_gap: 6
# diffusion_skip: True
# diffusion_skip_timestep: [1,1,1,0,0,0,0,0,0,0]

# mlp
mlp_skip: True

mlp_spatial_skip_config: {
676: {'block': [0, 1, 2, 3, 4], 'skip_count': 2},
788: {'block': [0, 1, 2, 3, 4], 'skip_count': 2},
864: {'block': [0, 1, 2, 3, 4], 'skip_count': 2},
}

mlp_temporal_skip_config: {
676: {'block': [0, 1, 2, 3, 4], 'skip_count': 2},
788: {'block': [0, 1, 2, 3, 4], 'skip_count': 2},
864: {'block': [0, 1, 2, 3, 4], 'skip_count': 2},
}



# final update
# timesteps: [tensor([1000.], device='cuda:0'), tensor([992.9313], device='cuda:0'), tensor([985.4678], device='cuda:0'), tensor([977.5753], device='cuda:0'), tensor([969.2160], device='cuda:0'), tensor([960.3470], device='cuda:0'), tensor([950.9203], device='cuda:0'), tensor([940.8815], device='cuda:0'), tensor([930.1691], device='cuda:0'), tensor([918.7130], device='cuda:0'), tensor([906.4327], device='cuda:0'), tensor([893.2362], device='cuda:0'), tensor([879.0170], device='cuda:0'), tensor([863.6513], device='cuda:0'), tensor([846.9945], device='cuda:0'), tensor([828.8770], device='cuda:0'), tensor([809.0977], device='cuda:0'), tensor([787.4170], device='cuda:0'), tensor([763.5468], device='cuda:0'), tensor([737.1379], device='cuda:0'), tensor([707.7626], device='cuda:0'), tensor([674.8911], device='cuda:0'), tensor([637.8601], device='cuda:0'), tensor([595.8265], device='cuda:0'), tensor([547.7032], device='cuda:0'), tensor([492.0635], device='cuda:0'), tensor([426.9973], device='cuda:0'), tensor([349.8871], device='cuda:0'), tensor([257.0481], device='cuda:0'), tensor([143.1210], device='cuda:0')]
# timesteps: [tensor([1000.], device='cuda:0'), tensor([992.9313], device='cuda:0'), tensor([985.4678], device='cuda:0'), tensor([977.5753], device='cuda:0'), tensor([969.2160], device='cuda:0'), tensor([960.3470], device='cuda:0'), tensor([950.9203], device='cuda:0'), tensor([940.8815], device='cuda:0'), tensor([930.1691], device='cuda:0'), tensor([918.7130], device='cuda:0'), tensor([906.4327], device='cuda:0'), tensor([893.2362], device='cuda:0'), tensor([879.0170], device='cuda:0'), tensor([863.6513], device='cuda:0'), tensor([846.9945], device='cuda:0'), tensor([828.8770], device='cuda:0'), tensor([809.0977], device='cuda:0'), tensor([787.4170], device='cuda:0'), tensor([763.5468], device='cuda:0'), tensor([737.1379], device='cuda:0'), tensor([707.7626], device='cuda:0'), tensor([674.8911], device='cuda:0'), tensor([637.8601], device='cuda:0'), tensor([595.8265], device='cuda:0'), tensor([547.7032], device='cuda:0'), tensor([492.0635], device='cuda:0'), tensor([426.9973], device='cuda:0'), tensor([349.8871], device='cuda:0'), tensor([257.0481], device='cuda:0'), tensor([143.1210], device='cuda:0')]
18 changes: 18 additions & 0 deletions configs/opensora_plan/sample_221f.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Open-Sora-Plan v1.1.0 sampling config: 221-frame, 512x512 video generation.
model_path: LanguageBind/Open-Sora-Plan-v1.1.0
version: 221x512x512
# Output geometry — should match the `version` preset above.
num_frames: 221
height: 512
width: 512
cache_dir: "./cache_dir"
# T5 text encoder used to embed the prompts below.
text_encoder_name: DeepFloyd/t5-v1_1-xxl
text_prompt: [
"Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field.",
"A movie trailer featuring the adventures of the 30 year old space man wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors.",
"Animated scene features a close-up of a short fluffy monster kneeling beside a melting red candle. The art style is 3D and realistic, with a focus on lighting and texture. The mood of the painting is one of wonder and curiosity, as the monster gazes at the flame with wide eyes and open mouth. Its pose and expression convey a sense of innocence and playfulness, as if it is exploring the world around it for the first time. The use of warm colors and dramatic lighting further enhances the cozy atmosphere of the image.",
]
# Causal VAE used to decode latents back to pixels.
ae: CausalVAEModel_4x8x8
save_img_path: "./samples/opensora_plan"
fps: 24
guidance_scale: 7.5
num_sampling_steps: 150
# Tiled VAE decoding to reduce peak memory on long/large videos.
enable_tiling: True
51 changes: 51 additions & 0 deletions configs/opensora_plan/sample_65f_pab.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,54 @@ cross_threshold: [100, 850]
cross_gap: 6
# diffusion_skip: True
# diffusion_skip_timestep: [3,3,3,0,0,0,0,0,0,0]


# mlp
mlp_skip: True

mlp_spatial_skip_config: {
738: {'block': [0, 1, 2, 3, 4, 5, 6], 'skip_count': 2},
714: {'block': [0, 1, 2, 3, 4, 5, 6], 'skip_count': 2},
690: {'block': [0, 1, 2, 3, 4, 5, 6], 'skip_count': 2},
666: {'block': [0, 1, 2, 3, 4, 5, 6], 'skip_count': 2},
642: {'block': [0, 1, 2, 3, 4, 5, 6], 'skip_count': 2},
618: {'block': [0, 1, 2, 3, 4, 5, 6], 'skip_count': 2},
594: {'block': [0, 1, 2, 3, 4, 5, 6], 'skip_count': 2},
570: {'block': [0, 1, 2, 3, 4, 5, 6], 'skip_count': 2},
546: {'block': [0, 1, 2, 3, 4, 5, 6], 'skip_count': 2},
522: {'block': [0, 1, 2, 3, 4, 5, 6], 'skip_count': 2},
498: {'block': [0, 1, 2, 3, 4, 5, 6], 'skip_count': 2},
474: {'block': [0, 1, 2, 3, 4, 5, 6], 'skip_count': 2},
450: {'block': [0, 1, 2, 3, 4, 5, 6], 'skip_count': 2},
426: {'block': [0, 1, 2, 3, 4, 5, 6], 'skip_count': 2},
}

mlp_temporal_skip_config: {
738: {'block': [0, 1, 2, 3, 4, 5, 6], 'skip_count': 2},
714: {'block': [0, 1, 2, 3, 4, 5, 6], 'skip_count': 2},
690: {'block': [0, 1, 2, 3, 4, 5, 6], 'skip_count': 2},
666: {'block': [0, 1, 2, 3, 4, 5, 6], 'skip_count': 2},
642: {'block': [0, 1, 2, 3, 4, 5, 6], 'skip_count': 2},
618: {'block': [0, 1, 2, 3, 4, 5, 6], 'skip_count': 2},
594: {'block': [0, 1, 2, 3, 4, 5, 6], 'skip_count': 2},
570: {'block': [0, 1, 2, 3, 4, 5, 6], 'skip_count': 2},
546: {'block': [0, 1, 2, 3, 4, 5, 6], 'skip_count': 2},
522: {'block': [0, 1, 2, 3, 4, 5, 6], 'skip_count': 2},
498: {'block': [0, 1, 2, 3, 4, 5, 6], 'skip_count': 2},
474: {'block': [0, 1, 2, 3, 4, 5, 6], 'skip_count': 2},
450: {'block': [0, 1, 2, 3, 4, 5, 6], 'skip_count': 2},
426: {'block': [0, 1, 2, 3, 4, 5, 6], 'skip_count': 2},
}

# Time step tensor([894, 891, 891, 888, 888, 885, 885, 882, 882, 879, 879, 876, 876, 870,
# 864, 858, 852, 846, 840, 834, 828, 822, 816, 810, 804, 798, 792, 786,
# 780, 774, 768, 762, 756, 750, 744, 738, 732, 726, 720, 714, 708, 702,
# 696, 690, 684, 678, 672, 666, 660, 654, 648, 642, 636, 630, 624, 618,
# 612, 606, 600, 594, 588, 582, 576, 570, 564, 558, 552, 546, 540, 534,
# 528, 522, 516, 510, 504, 498, 492, 486, 480, 474, 468, 462, 456, 450,
# 444, 438, 432, 426, 420, 414, 408, 402, 396, 390, 384, 378, 372, 366,
# 360, 354, 348, 342, 336, 330, 324, 318, 312, 306, 300, 294, 288, 282,
# 276, 270, 264, 258, 252, 246, 240, 234, 228, 222, 216, 210, 204, 198,
# 192, 186, 180, 174, 168, 162, 156, 150, 144, 138, 132, 126, 120, 114,
# 108, 102, 96, 90, 84, 78, 72, 66, 60, 54, 48, 42, 36, 30,
# 24, 18, 12, 6, 0], device='cuda:0')
107 changes: 107 additions & 0 deletions opendit/core/pab_mgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ def __init__(
diffusion_skip: bool,
diffusion_timestep_respacing: list,
diffusion_skip_timestep: list,
mlp_skip: bool,
mlp_spatial_skip_config: dict,
mlp_temporal_skip_config: dict,
):
self.steps = steps

Expand All @@ -43,6 +46,13 @@ def __init__(
self.diffusion_timestep_respacing = diffusion_timestep_respacing
self.diffusion_skip_timestep = diffusion_skip_timestep

self.mlp_skip = mlp_skip
self.mlp_spatial_skip_config = mlp_spatial_skip_config
self.mlp_temporal_skip_config = mlp_temporal_skip_config

self.temporal_mlp_outputs = {}
self.spatial_mlp_outputs = {}


class PABManager:
def __init__(self, config: PABConfig):
Expand Down Expand Up @@ -93,6 +103,89 @@ def if_broadcast_spatial(self, timestep: int, count: int, block_idx: int):
count = (count + 1) % self.config.steps
return flag, count

@staticmethod
def _is_t_in_skip_config(all_timesteps, timestep, config):
is_t_in_skip_config = False
for key in config:
index = all_timesteps.index(key)
skip_range = all_timesteps[index : index + 1 + int(config[key]["skip_count"])]
if timestep in skip_range:
is_t_in_skip_config = True
skip_range = [all_timesteps[index], all_timesteps[index + int(config[key]["skip_count"])]]
break
return is_t_in_skip_config, skip_range

def if_skip_mlp(self, timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
    """Decide whether this block's MLP output is reused at `timestep`.

    Returns:
        (flag, count, next_flag, skip_range): `flag` — reuse the cached
        output instead of computing; `next_flag` — compute now and cache
        for the upcoming skipped steps (anchor step of a window);
        `skip_range` — the matched skip window from the config.
    """
    cur_config = (
        self.config.mlp_temporal_skip_config if is_temporal else self.config.mlp_spatial_skip_config
    )
    is_t_in_skip_config, skip_range = self._is_t_in_skip_config(all_timesteps, timestep, cur_config)

    flag = False
    next_flag = False
    if self.config.mlp_skip and timestep is not None:
        if timestep in cur_config and block_idx in cur_config[timestep]["block"]:
            # Anchor step of a window: compute normally and cache the result.
            next_flag = True
            count = count + 1
        elif is_t_in_skip_config and block_idx in cur_config[skip_range[0]]["block"]:
            # Inside a window: skip the MLP and reuse the anchor's output.
            flag = True
            count = 0

    return flag, count, next_flag, skip_range

def save_skip_output(self, timestep, block_idx, ff_output, is_temporal=False):
    """Cache a feed-forward output under (timestep, block_idx) for later reuse."""
    cache = self.config.temporal_mlp_outputs if is_temporal else self.config.spatial_mlp_outputs
    cache[(timestep, block_idx)] = ff_output

def get_mlp_output(self, skip_range, timestep, block_idx, is_temporal=False):
    """Fetch the cached MLP output stored at the window's anchor timestep.

    Frees the cache entry once `timestep` reaches the last step of the
    skip window. Raises ValueError when nothing was cached for this
    window/block combination.
    """
    skip_start_t = skip_range[0]
    cache = self.config.temporal_mlp_outputs if is_temporal else self.config.spatial_mlp_outputs
    skip_output = None if cache is None else cache.get((skip_start_t, block_idx), None)

    if skip_output is None:
        raise ValueError(
            f"No stored MLP output found | t {timestep} |[{skip_range[0]}, {skip_range[-1]}] | block {block_idx}"
        )

    if timestep == skip_range[-1]:
        # TODO: save memory
        del cache[(skip_start_t, block_idx)]

    return skip_output

def get_spatial_mlp_outputs(self):
    """Return the dict caching spatial MLP outputs, keyed by (timestep, block_idx)."""
    return self.config.spatial_mlp_outputs

def get_temporal_mlp_outputs(self):
    """Return the dict caching temporal MLP outputs, keyed by (timestep, block_idx)."""
    return self.config.temporal_mlp_outputs


def set_pab_manager(config: PABConfig):
global PAB_MANAGER
Expand Down Expand Up @@ -132,6 +225,20 @@ def if_broadcast_spatial(timestep: int, count: int, block_idx: int):
return PAB_MANAGER.if_broadcast_spatial(timestep, count, block_idx)


def if_broadcast_mlp(timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
    """Decide whether this block's MLP output can be broadcast (reused) at `timestep`.

    Delegates to the global PAB manager. Returns the same 4-tuple as
    `PABManager.if_skip_mlp`: (flag, count, next_flag, skip_range).
    """
    if not enable_pab():
        # Fix: must match the enabled path's 4-tuple arity; returning
        # (False, count) made callers' 4-way unpack raise ValueError
        # whenever PAB was disabled.
        return False, count, False, None
    return PAB_MANAGER.if_skip_mlp(timestep, count, block_idx, all_timesteps, is_temporal)


def save_mlp_output(timestep: int, block_idx: int, ff_output, is_temporal=False):
    """Cache an MLP output via the global PAB manager for reuse on skipped steps.

    NOTE(review): unlike if_broadcast_mlp there is no enable_pab() guard here —
    presumably only called after if_broadcast_mlp signalled next_flag; confirm.
    """
    return PAB_MANAGER.save_skip_output(timestep, block_idx, ff_output, is_temporal)


def get_mlp_output(skip_range, timestep, block_idx: int, is_temporal=False):
    """Fetch a previously cached MLP output from the global PAB manager."""
    return PAB_MANAGER.get_mlp_output(skip_range, timestep, block_idx, is_temporal)


def get_diffusion_skip():
    """Return True when PAB is enabled and diffusion-step skipping is configured."""
    return enable_pab() and PAB_MANAGER.config.diffusion_skip

Expand Down
2 changes: 0 additions & 2 deletions opendit/models/latte/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from .latte_t2v import LatteT2V
from .pipeline import LattePipeline

__all__ = [
"LatteT2V",
"LattePipeline",
]
Loading

0 comments on commit 994a348

Please sign in to comment.