Jiatao Gu, Shuangfei Zhai, Yizhe Zhang, Josh Susskind & Navdeep Jaitly, ICLR 2024
Generative models have taken a notable step forward with Matryoshka Diffusion Models (MDM), which target high-resolution image and video synthesis. MDM stands out by breaking away from the usual cascaded or latent diffusion pipelines: it operates directly in pixel space rather than in a learned latent space, which keeps training end-to-end and simplifies conditioning on text representations. This matters for tasks that need high-resolution generation without the added complexity of multi-stage training or inference pipelines.
- Multi-Resolution Diffusion: MDM defines a multi-resolution diffusion process that exploits the hierarchical structure of image formation. Noise is removed at several resolutions simultaneously by a single network, an architecture referred to as the Nested UNet. Sharing computation across scales in this way significantly improves the optimization of high-resolution generation.
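To make the simultaneous denoising concrete, below is a minimal PyTorch-style sketch of a single reverse step applied jointly to every resolution. The `model` callable, the DDIM-style update, and the shared noise schedule are illustrative assumptions made for this summary, not the paper's exact sampler.

```python
import torch

@torch.no_grad()
def joint_denoise_step(model, z_t, t, alpha, sigma, alpha_prev, sigma_prev):
    """One reverse step applied jointly to all resolutions.

    z_t: list of noisy latents [z_t^1, ..., z_t^R], coarsest first.
    alpha/sigma (and *_prev): scalar schedule values at the current and previous
        step -- a single schedule shared across resolutions, a simplification.
    """
    x_pred = model(z_t, t)                     # per-resolution predictions of D^r(x)
    z_prev = []
    for z, x_r in zip(z_t, x_pred):
        eps = (z - alpha * x_r) / sigma        # implied noise estimate at this resolution
        z_prev.append(alpha_prev * x_r + sigma_prev * eps)   # DDIM-style update
    return z_prev
```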
- Training: MDM adds a progressive training schedule: training starts with low-resolution models and gradually brings higher-resolution inputs into the objective. This phased curriculum reduces the compute spent early in training and markedly improves convergence speed and final model quality.
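A rough sketch of such a phased schedule is shown below; the phase lengths, the resolution ladder, and the `set_active_resolutions` hook are hypothetical placeholders rather than values from the paper.

```python
import itertools

def progressive_train(model, optimizer, data_loader, loss_fn,
                      phases=({"res": (64,), "steps": 300_000},
                              {"res": (64, 256), "steps": 200_000},
                              {"res": (64, 256, 1024), "steps": 100_000})):
    """Train low resolutions first, then progressively activate higher ones.

    loss_fn is assumed to compute the joint multi-resolution denoising loss;
    set_active_resolutions is a hypothetical hook on the Nested UNet.
    """
    for phase in phases:
        model.set_active_resolutions(phase["res"])   # enable the nested sub-networks
        data_iter = itertools.cycle(data_loader)     # sketch only: naively cycles the loader
        for _ in range(phase["steps"]):
            x = next(data_iter)                      # batch of full-resolution images
            loss = loss_fn(model, x)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
```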
- Empirical Validation: Trained on the CC12M dataset (about 12 million image-text pairs), MDM achieves competitive high-resolution synthesis without the very large datasets that comparable models require; Stable Diffusion, for instance, is trained on over 2 billion images.
- Model in Extended Space: Unlike cascaded or latent methods, MDM learns a single diffusion process with hierarchical structure by introducing a multi-resolution diffusion process in an extended space.
Given a data point $x \in \mathbb{R}^N$, the time-dependent latent $z_t = \left[z^1_t, \ldots, z^R_t\right] \in \mathbb{R}^{N_1 + \cdots + N_R}$ is defined. For each $z^r_t$, $r = 1, \ldots, R$:

$$q(z^r_t \mid x) = \mathcal{N}\!\left(z^r_t;\; \alpha^r_t D^r(x),\; {\sigma^r_t}^2 I\right),$$

where $D^r: \mathbb{R}^N \to \mathbb{R}^{N_r}$ is a deterministic downsampling operator (e.g., average pooling) that produces a coarse version of $x$ at resolution $r$, with $D^R(x) = x$, and $\{\alpha^r_t, \sigma^r_t\}$ is the (possibly resolution-dependent) noise schedule.
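As an illustration of this forward process, the sketch below builds the extended latent by downsampling the image and adding Gaussian noise at each resolution. Average pooling as $D^r$ and the per-resolution schedule values passed in as `alphas`/`sigmas` are assumptions consistent with the definition above.

```python
import torch
import torch.nn.functional as F

def forward_diffuse(x, alphas, sigmas, factors=(16, 4, 1)):
    """Sample z_t = [z_t^1, ..., z_t^R] from q(z_t | x).

    x: clean images at the highest resolution, shape (B, C, H, W).
    alphas, sigmas: per-resolution schedule values at the current timestep.
    factors: downsampling factor per level; 1 means full resolution, so D^R(x) = x.
    """
    z_t = []
    for r, f in enumerate(factors):
        d_r = F.avg_pool2d(x, kernel_size=f) if f > 1 else x     # D^r(x)
        noise = torch.randn_like(d_r)
        z_t.append(alphas[r] * d_r + sigmas[r] * noise)          # N(alpha^r_t D^r(x), (sigma^r_t)^2 I)
    return z_t
```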
- Nested U-Net: MDM follows the overall structure of a UNet, with skip connections and computation blocks that preserve detailed input information. The progressive compression assumption naturally aligns the computations performed at different resolutions, which leads to the NestedUNet architecture: the network for a lower resolution is nested inside the network for the next higher resolution, so computation is shared across resolutions and learning high-resolution generation becomes simpler. Compared with other hierarchical approaches, NestedUNet is both simpler and allows computation to be allocated where it matters most, improving scalability.
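The two-level sketch below conveys the nesting idea: a low-resolution UNet sits at the bottleneck of a higher-resolution one, so both resolutions are predicted in a single forward pass. Channel sizes, the feature-fusion scheme, and the omission of timestep and text conditioning are simplifications for this summary, not the paper's exact architecture.

```python
import torch
import torch.nn as nn

class TinyUNet(nn.Module):
    """Placeholder single-resolution UNet body (illustrative only)."""
    def __init__(self, ch):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(ch, ch, 3, padding=1), nn.SiLU(),
            nn.Conv2d(ch, ch, 3, padding=1),
        )

    def forward(self, h):
        return h + self.body(h)

class NestedUNet(nn.Module):
    """Two-level nesting: the low-resolution network is reused inside the
    high-resolution one, sharing computation across scales."""
    def __init__(self, ch=64, in_ch=3):
        super().__init__()
        self.enc_hi = nn.Conv2d(in_ch, ch, 3, stride=2, padding=1)  # high-res encoder (downsamples)
        self.enc_lo = nn.Conv2d(in_ch, ch, 3, padding=1)            # low-res input projection
        self.inner = TinyUNet(ch)                                   # nested low-resolution UNet
        self.dec_hi = nn.ConvTranspose2d(ch, ch, 4, stride=2, padding=1)
        self.out_lo = nn.Conv2d(ch, in_ch, 3, padding=1)
        self.out_hi = nn.Conv2d(ch, in_ch, 3, padding=1)

    def forward(self, z_lo, z_hi):
        h_hi = self.enc_hi(z_hi)                # high-res features, now at low resolution
        h_lo = self.enc_lo(z_lo) + h_hi         # fuse the two streams
        h_lo = self.inner(h_lo)                 # shared low-resolution computation
        x_lo = self.out_lo(h_lo)                # prediction at the low resolution
        x_hi = self.out_hi(self.dec_hi(h_lo))   # decode back up to the high resolution
        return x_lo, x_hi
```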
- Loss Function: The usual denoising objective is defined jointly over all resolutions; it minimizes the difference between each downsampled image $D^r(x)$ and the corresponding model prediction $x^r_\theta(z_t, t)$:

$$\mathcal{L}_\theta = \mathbb{E}_{t,\, x,\, z_t \sim q(z_t \mid x)} \left[\, \sum_{r=1}^{R} \omega^r_t \left\| x^r_\theta(z_t, t) - D^r(x) \right\|_2^2 \right],$$

where $\omega^r_t$ is a resolution-dependent weighting term.
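A compact sketch of this objective, assuming x-prediction, a single cosine schedule shared across resolutions, and constant weights $\omega^r_t$ (all simplifications relative to the paper's setup):

```python
import torch
import torch.nn.functional as F

def multi_resolution_loss(model, x, weights=(1.0, 1.0, 1.0), factors=(16, 4, 1)):
    """L = E_t [ sum_r w_r * || x_theta^r(z_t, t) - D^r(x) ||^2 ]  (sketch).

    model(z_t, t) is assumed to return one x-prediction per resolution.
    """
    b = x.shape[0]
    t = torch.rand(b, device=x.device)                      # t ~ U[0, 1]
    alpha = torch.cos(0.5 * torch.pi * t).view(b, 1, 1, 1)  # signal scale
    sigma = torch.sin(0.5 * torch.pi * t).view(b, 1, 1, 1)  # noise scale
    targets = [F.avg_pool2d(x, f) if f > 1 else x for f in factors]      # D^r(x)
    z_t = [alpha * d + sigma * torch.randn_like(d) for d in targets]     # q(z_t | x)
    preds = model(z_t, t)                                   # per-resolution predictions
    return sum(w * F.mse_loss(p, d) for w, p, d in zip(weights, preds, targets))
```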
In the reported comparisons, MDM outperforms the alternative pixel-space and cascaded baselines, delivering better sample quality and converging more rapidly.
Figure: samples from the model trained on CC12M at 1024×1024 resolution.
- The introduction of MDM steers the conversation toward the efficiency and scalability of diffusion models for high-resolution synthesis. Its architectural design and training methodology offer a finer understanding of how multi-resolution processing can be harnessed, suggesting a potential paradigm shift in how generative models are trained.
- Further refinement of the Nested UNet architecture, exploration of alternative weight-sharing mechanisms, and integration of autoencoder-based approaches within the MDM framework are promising directions for future work.