-
Notifications
You must be signed in to change notification settings - Fork 131
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
copy models from dit repo and add docstring
- Loading branch information
Showing
12 changed files
with
2,593 additions
and
1 deletion.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
# Modified from OpenAI's diffusion repos and Meta DiT | ||
# DiT: https://github.com/facebookresearch/DiT/tree/main | ||
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py | ||
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion | ||
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py | ||
|
||
from . import gaussian_diffusion as gd | ||
from .respace import SpacedDiffusion, space_timesteps | ||
|
||
|
||
def create_diffusion(
    timestep_respacing,
    noise_schedule="linear",
    use_kl=False,
    sigma_small=False,
    predict_xstart=False,
    learn_sigma=True,
    rescale_learned_sigmas=False,
    diffusion_steps=1000,
):
    """
    Build a SpacedDiffusion object from high-level configuration flags.

    :param timestep_respacing: respacing spec passed to space_timesteps;
        None or "" means "use all diffusion_steps timesteps".
    :param noise_schedule: name of the beta schedule (e.g. "linear").
    :param use_kl: train with the rescaled variational bound.
    :param sigma_small: if not learning sigma, use the small fixed variance.
    :param predict_xstart: model predicts x_0 instead of epsilon.
    :param learn_sigma: model outputs a learned variance range.
    :param rescale_learned_sigmas: use rescaled MSE for the learned sigmas.
    :param diffusion_steps: number of training diffusion steps.
    :return: a configured SpacedDiffusion instance.
    """
    betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)

    # Pick the training loss: KL-based bounds take priority over MSE variants.
    if use_kl:
        loss_type = gd.LossType.RESCALED_KL
    elif rescale_learned_sigmas:
        loss_type = gd.LossType.RESCALED_MSE
    else:
        loss_type = gd.LossType.MSE

    # An unset respacing spec means sampling over the full schedule.
    if timestep_respacing is None or timestep_respacing == "":
        timestep_respacing = [diffusion_steps]

    if predict_xstart:
        mean_type = gd.ModelMeanType.START_X
    else:
        mean_type = gd.ModelMeanType.EPSILON

    if learn_sigma:
        var_type = gd.ModelVarType.LEARNED_RANGE
    elif sigma_small:
        var_type = gd.ModelVarType.FIXED_SMALL
    else:
        var_type = gd.ModelVarType.FIXED_LARGE

    return SpacedDiffusion(
        use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
        betas=betas,
        model_mean_type=mean_type,
        model_var_type=var_type,
        loss_type=loss_type,
        # rescale_timesteps=rescale_timesteps,
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
# Modified from OpenAI's diffusion repos and Meta DiT | ||
# DiT: https://github.com/facebookresearch/DiT/tree/main | ||
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py | ||
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion | ||
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py | ||
|
||
import numpy as np | ||
import torch as th | ||
|
||
|
||
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Compute the KL divergence between two gaussians.

    Shapes are automatically broadcasted, so batches can be compared to
    scalars, among other use cases.

    :param mean1: mean of the first Gaussian (Tensor or scalar).
    :param logvar1: log-variance of the first Gaussian.
    :param mean2: mean of the second Gaussian.
    :param logvar2: log-variance of the second Gaussian.
    :return: elementwise KL(N(mean1, var1) || N(mean2, var2)) in nats.
    """
    # Find a reference tensor to inherit dtype/device from.
    ref = next(
        (arg for arg in (mean1, logvar1, mean2, logvar2) if isinstance(arg, th.Tensor)),
        None,
    )
    assert ref is not None, "at least one argument must be a Tensor"

    # Broadcasting handles scalar means, but th.exp() requires Tensor
    # inputs, so promote scalar log-variances explicitly.
    if not isinstance(logvar1, th.Tensor):
        logvar1 = th.tensor(logvar1).to(ref)
    if not isinstance(logvar2, th.Tensor):
        logvar2 = th.tensor(logvar2).to(ref)

    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + th.exp(logvar1 - logvar2)
        + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
    )
|
||
|
||
def approx_standard_normal_cdf(x):
    """
    A fast approximation of the cumulative distribution function of the
    standard normal (the tanh-based form also used by GELU).

    :param x: input Tensor.
    :return: Tensor of approximate CDF values in (0, 1).
    """
    scaled = np.sqrt(2.0 / np.pi) * (x + 0.044715 * (x ** 3))
    return 0.5 * (th.tanh(scaled) + 1.0)
|
||
|
||
def continuous_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Compute the log-likelihood of a continuous Gaussian distribution.

    :param x: the targets.
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    centered_x = x - means
    inv_stdv = th.exp(-log_scales)
    normalized_x = centered_x * inv_stdv
    # log N(x; mean, scale) = log N(z; 0, 1) - log_scale, where
    # z = (x - mean) / scale. The previous version dropped the
    # change-of-variables Jacobian term (- log_scales), so it returned the
    # density of the standardized value rather than the density of x.
    log_probs = (
        th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
        - log_scales
    )
    return log_probs
|
||
|
||
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Compute the log-likelihood of a Gaussian distribution discretizing to a
    given image.

    :param x: the target images. It is assumed that this was uint8 values,
        rescaled to the range [-1, 1].
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    assert x.shape == means.shape == log_scales.shape

    # Each uint8 bucket spans 2/255 in [-1, 1]; integrate the Gaussian CDF
    # over the half-bucket on either side of x.
    half_bin = 1.0 / 255.0
    inv_std = th.exp(-log_scales)
    centered = x - means
    upper_cdf = approx_standard_normal_cdf(inv_std * (centered + half_bin))
    lower_cdf = approx_standard_normal_cdf(inv_std * (centered - half_bin))

    # Clamp before log so degenerate buckets don't produce -inf/NaN.
    log_upper = th.log(upper_cdf.clamp(min=1e-12))
    log_one_minus_lower = th.log((1.0 - lower_cdf).clamp(min=1e-12))
    log_bucket = th.log((upper_cdf - lower_cdf).clamp(min=1e-12))

    # Edge buckets integrate out to +/- infinity: the lowest value keeps the
    # full lower tail, the highest keeps the full upper tail.
    interior_or_top = th.where(x > 0.999, log_one_minus_lower, log_bucket)
    log_probs = th.where(x < -0.999, log_upper, interior_or_top)

    assert log_probs.shape == x.shape
    return log_probs
Oops, something went wrong.