Skip to content

Commit

Permalink
Fix check that sigma_min and sigma_max > 0 in DPM-Solver
Browse files Browse the repository at this point in the history
  • Loading branch information
crowsonkb committed Sep 30, 2022
1 parent b72536d commit 91decae
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions k_diffusion/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,8 +351,6 @@ def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1.):
def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81):
if order not in {2, 3}:
raise ValueError('order should be 2 or 3')
if t_start == 0 or t_end == 0:
raise ValueError('t_start and t_end should not be 0')
forward = t_end > t_start
if forward and h_init <= 0:
raise ValueError('For forward ODE integration, h_init must be positive')
Expand Down Expand Up @@ -400,6 +398,8 @@ def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078
@torch.no_grad()
def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1.):
"""DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927."""
if sigma_min <= 0 or sigma_max <= 0:
raise ValueError('sigma_min and sigma_max must not be 0')
with tqdm(total=n, disable=disable) as pbar:
dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
if callback is not None:
Expand All @@ -410,6 +410,8 @@ def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback
@torch.no_grad()
def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, return_info=False):
"""DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927."""
if sigma_min <= 0 or sigma_max <= 0:
raise ValueError('sigma_min and sigma_max must not be 0')
with tqdm(disable=disable) as pbar:
dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update)
if callback is not None:
Expand Down

0 comments on commit 91decae

Please sign in to comment.