MiniSora-DiT, a DiT reproduction based on XTuner from the open source community MiniSora
English | 简体中文
👋 加入我们的 微信社区
Mini Sora 开源社区定位为由社区同学自发组织的开源社区(免费不收取任何费用、不割韭菜),Mini Sora 计划探索 Sora 的实现路径和后续的发展方向:
- 将定期举办 Sora 的圆桌和社区一起探讨可能性
- 视频生成的现有技术路径探讨
招募MiniSora社区同学使用 XTuner
复现 DiT
, 希望领取任务同学有如下特点:
- 熟悉
OpenMMLab MMEngine
机制 - 熟悉
DiT
DiT
作者和Sora
作者为同一个XTuner
现有能够高效训练1000K
序列长度的核心技术
- 算力提供 2*A100
- XTuner 核心开发者 P佬@pppppM 会大力支持~
XTuner: https://github.com/internLM/xtuner
- ImageNet-1K
可以在 OpenDataLab 进行下载 ImageNet-1K
pip install openxlab #安装
pip install -U openxlab #版本升级
openxlab login #进行登录,输入对应的AK/SK
cd ${dataset_dir}
openxlab dataset get --dataset-repo OpenDataLab/ImageNet-1K #数据集下载
目前已在 dev 分支提交了 DiT 在纯 torch 下的复现代码 fast-DiT,该版本使用了混合精度还有一些加速方案,可以极大程度降低显存,以及提升训练速度。
- 环境安装
使用 dev 分支中的 environment.yml
可以复现环境
conda env create -f environment.yml
conda activate DiT
- 数据集预处理
因为在原版 Meta 的 DiT 中,每个 iter 都会对数据进行重复计算,为了节省训练的时间,可以先对图片进行预处理,在训练的时候可以节省这部分的时间
详见 dev 分支中的 extract_features.py#L163 ,处理需要时间较久,大概 1~2小时。
for x, y in loader:
x = x.to(device)
y = y.to(device)
with torch.no_grad():
# Map input images to latent space + normalize latents:
x = vae.encode(x).latent_dist.sample().mul_(0.18215)
x = x.detach().cpu().numpy() # (1, 4, 32, 32)
np.save(f'{args.features_path}/imagenet256_features/{train_steps}.npy', x)
y = y.detach().cpu().numpy() # (1,)
np.save(f'{args.features_path}/imagenet256_labels/{train_steps}.npy', y)
train_steps += 1
print(train_steps)
执行后会对每个图片生成一个 npy 文件,训练的时候直接读取
- 使用 mmengine 重写数据流,下面是原版的 dataset,可见直接读取上一步生成的 npy 文件,省去了前处理时间
class CustomDataset(Dataset):
def __init__(self, features_dir, labels_dir):
self.features_dir = features_dir
self.labels_dir = labels_dir
self.features_files = sorted(os.listdir(features_dir))
self.labels_files = sorted(os.listdir(labels_dir))
def __len__(self):
assert len(self.features_files) == len(self.labels_files), \
"Number of feature files and label files should be same"
return len(self.features_files)
def __getitem__(self, idx):
feature_file = self.features_files[idx]
label_file = self.labels_files[idx]
features = np.load(os.path.join(self.features_dir, feature_file))
labels = np.load(os.path.join(self.labels_dir, label_file))
return torch.from_numpy(features), torch.from_numpy(labels)
- 重写 loss 计算
- 使用 xtuner 调训练 pipeline
我们非常希望你们能够为 Mini Sora 开源社区做出贡献,并且帮助我们把它做得比现在更好!
具体查看贡献指南