enabled fully sharded decoding #181

Open · wants to merge 4 commits into base: master

videosys/core/comm.py (120 additions, 1 deletion)
@@ -255,7 +255,7 @@ def _split_sequence_func(input_, pg: dist.ProcessGroup, dim: int, pad: int):
     # skip if only one rank involved
     world_size = dist.get_world_size(pg)
     rank = dist.get_rank(pg)
-    if world_size == 1:
+    if world_size == 1 or input_.size(dim) < world_size:
         return input_

     if pad > 0:
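
The relaxed guard makes split_sequence a no-op whenever the target dimension has fewer elements than the group has ranks, presumably so the fully sharded VAE path below can hand it small shapes without special-casing.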
@@ -418,3 +418,122 @@ def all_to_all_with_pad(
        input_ = input_.narrow(gather_dim, 0, input_.size(gather_dim) - gather_pad)

    return input_

# ======================================================
# Halo Exchange
# ======================================================


def _halo_exchange_func(input_, pg: dist.ProcessGroup, dim: int, pad: int):
    world_size = dist.get_world_size(pg)
    rank = dist.get_rank(pg)
    # skip if only one rank is involved or the dim is too short to shard
    if world_size == 1 or input_.size(dim) < world_size:
        return input_
    rank_list = dist.get_process_group_ranks(pg)

    dst_l = (rank - 1) % world_size
    dst_r = (rank + 1) % world_size

    send_l = input_.narrow(dim, 0, pad).contiguous()
    send_r = input_.narrow(dim, input_.size(dim) - pad, pad).contiguous()
    recv_l = torch.zeros_like(send_l)
    recv_r = torch.zeros_like(send_r)

    # odd ranks send first and even ranks receive first, so every blocking
    # send is paired with a matching recv and the ring cannot deadlock
    is_odd = rank % 2 == 1
    dst_l = rank_list[dst_l]
    dst_r = rank_list[dst_r]
    if is_odd:
        dist.send(send_l, dst_l, group=pg)
        dist.send(send_r, dst_r, group=pg)
    else:
        dist.recv(recv_r, dst_r, group=pg)
        dist.recv(recv_l, dst_l, group=pg)
    if is_odd:
        dist.recv(recv_r, dst_r, group=pg)
        dist.recv(recv_l, dst_l, group=pg)
    else:
        dist.send(send_l, dst_l, group=pg)
        dist.send(send_r, dst_r, group=pg)

    # boundary ranks drop the wrap-around halo from the ring exchange
    if rank == 0:
        output = torch.cat([input_, recv_r], dim=dim)
    elif rank == world_size - 1:
        output = torch.cat([recv_l, input_], dim=dim)
    else:
        output = torch.cat([recv_l, input_, recv_r], dim=dim)

    return output


class _HaloExchange(torch.autograd.Function):
    """
    Halo exchange.

    Args:
        input_: input matrix.
        process_group: process group.
        dim: dimension along which to exchange halos.
        pad: halo width.
    """

    @staticmethod
    def symbolic(graph, input_, process_group, dim, pad):
        return _halo_exchange_func(input_, process_group, dim, pad)

    @staticmethod
    def forward(ctx, input_, process_group, dim, pad):
        ctx.process_group = process_group
        ctx.dim = dim
        ctx.pad = pad
        return _halo_exchange_func(input_, process_group, dim, pad)

    @staticmethod
    def backward(ctx, grad_output):
        raise NotImplementedError("Halo exchange does not support backward now.")


def halo_exchange(input_, process_group, dim, pad):
    return _HaloExchange.apply(input_, process_group, dim, pad)
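
halo_exchange hands every rank the pad-wide boundary slices of its neighbors, which is what lets a convolution run over a sharded dimension without seeing artificial zeros at shard boundaries. A minimal usage sketch, assuming a width-sharded 1-D conv with kernel size 3 (the helper and shapes below are illustrative, not part of the diff):

import torch.nn.functional as F

def _sharded_conv1d(x_local, weight, pg):
    # x_local: (N, C, W_local), one width shard of a global (N, C, W) tensor;
    # a size-3 kernel needs one extra column from each neighbor, so pad=1
    x_local = halo_exchange(x_local, pg, dim=2, pad=1)
    # a valid (padding=0) conv on the halo-extended shard produces exactly the
    # output columns this rank owns; concatenating the per-rank outputs along
    # width reproduces the unsharded valid conv
    return F.conv1d(x_local, weight)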


# ======================================================
# All Reduce
# ======================================================


def _all_reduce_func(input_, pg: dist.ProcessGroup, op: dist.ReduceOp):
    dist.all_reduce(input_, op=op, group=pg)
    return input_


class _AllReduce(torch.autograd.Function):
    """
    All reduce.

    Args:
        input_: input matrix.
        process_group: process group.
        op: reduce operation.
    """

    @staticmethod
    def symbolic(graph, input_, process_group, op):
        return _all_reduce_func(input_, process_group, op)

    @staticmethod
    def forward(ctx, input_, process_group, op):
        ctx.process_group = process_group
        ctx.op = op
        return _all_reduce_func(input_, process_group, op)

    @staticmethod
    def backward(ctx, grad_output):
        raise NotImplementedError("All reduce does not support backward now.")


def all_reduce(input_, process_group, op=dist.ReduceOp.SUM):
    return _AllReduce.apply(input_, process_group, op)
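
The all_reduce wrapper gives patched modules a way to aggregate statistics across shards; the canonical case here is sequence-parallel GroupNorm, which must normalize with the global mean and variance even though each rank holds only a strip of the sample. A hypothetical sketch of that reduction (the real patch is _replace_groupnorm_fwd in videosys.utils.vae_utils, which this diff only imports):

import torch

def _sharded_group_stats(x_local, num_groups, pg):
    # x_local: (N, C, ...), with one spatial dim sharded across the group
    n = x_local.shape[0]
    xg = x_local.reshape(n, num_groups, -1)
    stats = torch.stack([xg.sum(-1), (xg * xg).sum(-1)])  # local sum and sum of squares, (2, N, G)
    count = torch.full((1,), float(xg.shape[-1]), device=x_local.device)
    stats = all_reduce(stats, pg)  # global sums via the wrapper above
    count = all_reduce(count, pg)  # global element count per group
    mean = stats[0] / count
    var = stats[1] / count - mean * mean  # E[x^2] - E[x]^2
    return mean, var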
videosys/models/autoencoders/autoencoder_kl_open_sora.py (124 additions, 2 deletions)
@@ -17,7 +17,9 @@
from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
from einops import rearrange
from transformers import PretrainedConfig, PreTrainedModel

from videosys.utils.vae_utils import _replace_conv_fwd, _replace_groupnorm_fwd, _replace_conv_opensora_fwd, dynamic_switch
from videosys.core.parallel_mgr import enable_sequence_parallel, get_sequence_parallel_group, get_sequence_parallel_size, get_sequence_parallel_rank
from videosys.core.comm import split_sequence, gather_sequence

class DiagonalGaussianDistribution(object):
    def __init__(
@@ -119,6 +121,10 @@ def __init__(
        dilation = (dilation, 1, 1)
        self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)

    def set_sequence_parallel(self):
        _replace_conv_fwd(self.conv)
        _replace_conv_opensora_fwd(self)

    def forward(self, x):
        x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
        x = self.conv(x)
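
set_sequence_parallel swaps in sharding-aware forwards for both the inner nn.Conv3d and the CausalConv3d wrapper itself. The replacement functions come from videosys.utils.vae_utils and are not shown in this diff; presumably they build on the halo_exchange primitive added in comm.py so spatial kernels can read across shard boundaries.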
@@ -152,6 +158,14 @@ def __init__(
        else:
            self.conv3 = conv_fn(in_channels, self.filters, kernel_size=(1, 1, 1), bias=False)

    def set_sequence_parallel(self):
        _replace_groupnorm_fwd(self.norm1)
        _replace_groupnorm_fwd(self.norm2)
        self.conv1.set_sequence_parallel()
        self.conv2.set_sequence_parallel()
        if self.in_channels != self.filters:
            self.conv3.set_sequence_parallel()

    def forward(self, x):
        residual = x
        x = self.norm1(x)
@@ -256,6 +270,19 @@ def __init__(

        self.conv2 = self.conv_fn(prev_filters, self.embedding_dim, kernel_size=(1, 1, 1), padding="same")

    def set_sequence_parallel(self):
        self.conv_in.set_sequence_parallel()
        for i in range(self.num_blocks):
            for j in range(self.num_res_blocks):
                self.block_res_blocks[i][j].set_sequence_parallel()
            if i < self.num_blocks - 1:
                if isinstance(self.conv_blocks[i], CausalConv3d):
                    self.conv_blocks[i].set_sequence_parallel()
        for i in range(self.num_res_blocks):
            self.res_blocks[i].set_sequence_parallel()
        _replace_groupnorm_fwd(self.norm1)
        self.conv2.set_sequence_parallel()

    def forward(self, x):
        x = self.conv_in(x)

@@ -353,6 +380,19 @@ def __init__(

        self.conv_out = self.conv_fn(filters, in_out_channels, 3)

    def set_sequence_parallel(self):
        self.conv1.set_sequence_parallel()
        for i in range(self.num_res_blocks):
            self.res_blocks[i].set_sequence_parallel()
        for i in range(self.num_blocks):
            for j in range(self.num_res_blocks):
                self.block_res_blocks[i][j].set_sequence_parallel()
            if i > 0:
                if isinstance(self.conv_blocks[i - 1], CausalConv3d):
                    self.conv_blocks[i - 1].set_sequence_parallel()
        _replace_groupnorm_fwd(self.norm1)
        self.conv_out.set_sequence_parallel()

    def forward(self, x):
        x = self.conv1(x)
        for i in range(self.num_res_blocks):
@@ -439,6 +479,23 @@ def get_latent_size(self, input_size):
            latent_size.append(lsize)
        return latent_size

    def get_video_size(self, latent_size):
        video_size = []
        for i in range(3):
            if latent_size[i] is None:
                vsize = None
            elif i == 0:
                time_padding = (
                    0
                    if (latent_size[i] % self.time_downsample_factor == 0)
                    else self.time_downsample_factor - latent_size[i] % self.time_downsample_factor
                )
                vsize = latent_size[i] * self.patch_size[i] - time_padding
            else:
                vsize = latent_size[i] * self.patch_size[i]
            video_size.append(vsize)
        return video_size

    def encode(self, x):
        time_padding = (
            0
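
get_video_size inverts get_latent_size: given a latent shape it recovers the frame count and spatial extent of the video before the encoder's padding, which the pipeline's decode path below uses to crop gathered output back to the true size.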
@@ -546,6 +603,15 @@ def get_latent_size(self, input_size):
            # ), "Input size must be divisible by patch size"
            latent_size.append(input_size[i] // self.patch_size[i] if input_size[i] is not None else None)
        return latent_size

    def get_video_size(self, latent_size):
        video_size = []
        for i in range(3):
            # assert (
            #     latent_size[i] is None or latent_size[i] % self.patch_size[i] == 0
            # ), "Latent size must be divisible by patch size"
            video_size.append(latent_size[i] * self.patch_size[i] if latent_size[i] is not None else None)
        return video_size

    @property
    def device(self):
@@ -650,10 +716,27 @@ def __init__(self, config: VideoAutoencoderPipelineConfig):
        shift = shift[None, :, None, None, None]
        self.register_buffer("scale", scale)
        self.register_buffer("shift", shift)
        if enable_sequence_parallel():
            self.set_sequence_parallel()

    def set_sequence_parallel(self):
        self.temporal_vae.encoder.set_sequence_parallel()
        self.temporal_vae.decoder.set_sequence_parallel()
        self.temporal_vae.quant_conv.set_sequence_parallel()
        self.temporal_vae.post_quant_conv.set_sequence_parallel()

    def encode(self, x):
        if enable_sequence_parallel():
            # pad the frame dim (2) up to a multiple of the group size, then shard it
            padding_f = (-x.shape[2]) % get_sequence_parallel_size()
            x = F.pad(x, (0, 0, 0, 0, 0, padding_f))
            x = split_sequence(x, get_sequence_parallel_group(), dim=2)
        x_z = self.spatial_vae.encode(x)

        padding_s = 0
        if enable_sequence_parallel():
            # re-shard: pad the width dim (4), move the shard there, trim the frame padding
            padding_s = (-x_z.shape[4]) % get_sequence_parallel_size()
            x_z = F.pad(x_z, (0, padding_s, 0, 0, 0, 0))
            x_z = dynamic_switch(x_z, True, 2, 4)
            x_z = x_z.narrow(2, 0, x_z.shape[2] - padding_f)
        if self.micro_frame_size is None:
            posterior = self.temporal_vae.encode(x_z)
            z = posterior.sample()
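
The sequence-parallel encode path shards twice: the input is padded and split along the frame dim so each rank runs the spatial VAE on a slice of frames, then dynamic_switch (whose call sites imply it gathers one dim and re-splits along another) moves the shard to the width dim and the frame padding is trimmed, so the temporal VAE sees every frame but only a strip of width.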
@@ -664,6 +747,10 @@ def encode(self, x):
                posterior = self.temporal_vae.encode(x_z_bs)
                z_list.append(posterior.sample())
            z = torch.cat(z_list, dim=2)

        if enable_sequence_parallel():
            # undo the width sharding and the width padding on the latent
            z = gather_sequence(z, get_sequence_parallel_group(), dim=4)
            z = z.narrow(4, 0, z.shape[4] - padding_s)

        if self.cal_loss:
            return z, posterior, x_z
@@ -674,8 +761,18 @@ def decode(self, z, num_frames=None):
        if not self.cal_loss:
            z = z * self.scale.to(z.dtype) + self.shift.to(z.dtype)

        if enable_sequence_parallel():
            # remember the true output width, then pad and shard the latent width dim (4)
            padding_s = (-z.shape[4]) % get_sequence_parallel_size()
            expected_s = self.get_video_size(z.shape[2:])[2]
            z = F.pad(z, (0, padding_s, 0, 0, 0, 0))
            z = split_sequence(z, get_sequence_parallel_group(), dim=4)

        if self.micro_frame_size is None:
            x_z = self.temporal_vae.decode(z, num_frames=num_frames)
            if enable_sequence_parallel():
                # pad the frame dim and move the shard from width (4) back to frames (2)
                padding_f = (-x_z.shape[2]) % get_sequence_parallel_size()
                x_z = F.pad(x_z, (0, 0, 0, 0, 0, padding_f))
                x_z = dynamic_switch(x_z, False, 2, 4)
            x = self.spatial_vae.decode(x_z)
        else:
            x_z_list = []
@@ -685,8 +782,20 @@ def decode(self, z, num_frames=None):
                x_z_list.append(x_z_bs)
                num_frames -= self.micro_frame_size
            x_z = torch.cat(x_z_list, dim=2)
            if enable_sequence_parallel():
                padding_f = (-x_z.shape[2]) % get_sequence_parallel_size()
                x_z = F.pad(x_z, (0, 0, 0, 0, 0, padding_f))
                x_z = dynamic_switch(x_z, False, 2, 4)
            x = self.spatial_vae.decode(x_z)

        if enable_sequence_parallel():
            # gather the frame shards, then strip the frame padding and crop the width
            x = gather_sequence(x, get_sequence_parallel_group(), dim=2)
            x_z = gather_sequence(x_z, get_sequence_parallel_group(), dim=2)
            x = x.narrow(2, 0, x.shape[2] - padding_f)
            x = x.narrow(4, 0, min(expected_s, x.shape[4]))
            x_z = x_z.narrow(2, 0, x_z.shape[2] - padding_f)
            x_z = x_z.narrow(4, 0, x_z.shape[4] - padding_s)

        if self.cal_loss:
            return x, x_z
        else:
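
Decode mirrors encode in reverse: the latent is width-sharded for the temporal VAE, re-sharded to frames for the spatial VAE, and finally gathered along frames. expected_s, computed from get_video_size before any padding, lets the last narrow crop the width back to the exact output size.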
@@ -710,6 +819,19 @@ def get_latent_size(self, input_size):
            remain_size = self.temporal_vae.get_latent_size(remain_temporal_size)
            sub_latent_size[0] += remain_size[0]
        return sub_latent_size

    def get_video_size(self, latent_size):
        if self.micro_frame_size is None or latent_size[0] is None:
            return self.spatial_vae.get_video_size(self.temporal_vae.get_video_size(latent_size))
        else:
            sub_latent_size = [self.micro_z_frame_size, latent_size[1], latent_size[2]]
            sub_video_size = self.spatial_vae.get_video_size(self.temporal_vae.get_video_size(sub_latent_size))
            sub_video_size[0] = sub_video_size[0] * (latent_size[0] // self.micro_z_frame_size)
            remain_temporal_size = [latent_size[0] % self.micro_z_frame_size, None, None]
            if remain_temporal_size[0] > 0:
                remain_size = self.spatial_vae.get_video_size(self.temporal_vae.get_video_size(remain_temporal_size))
                sub_video_size[0] += remain_size[0]
            return sub_video_size

    def get_temporal_last_layer(self):
        return self.temporal_vae.decoder.conv_out.conv.weight