From 91decaeaa6e9006ed69f89649f8641a053cfcfc3 Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Fri, 30 Sep 2022 16:12:22 -0700 Subject: [PATCH] Fix check that sigma_min and sigma_max > 0 in DPM-Solver --- k_diffusion/sampling.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/k_diffusion/sampling.py b/k_diffusion/sampling.py index 9b301c5b..a9c3333c 100644 --- a/k_diffusion/sampling.py +++ b/k_diffusion/sampling.py @@ -351,8 +351,6 @@ def dpm_solver_fast(self, x, t_start, t_end, nfe, eta=0., s_noise=1.): def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81): if order not in {2, 3}: raise ValueError('order should be 2 or 3') - if t_start == 0 or t_end == 0: - raise ValueError('t_start and t_end should not be 0') forward = t_end > t_start if forward and h_init <= 0: raise ValueError('For forward ODE integration, h_init must be positive') @@ -400,6 +398,8 @@ def dpm_solver_adaptive(self, x, t_start, t_end, order=3, rtol=0.05, atol=0.0078 @torch.no_grad() def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback=None, disable=None, eta=0., s_noise=1.): """DPM-Solver-Fast (fixed step size). See https://arxiv.org/abs/2206.00927.""" + if sigma_min <= 0 or sigma_max <= 0: + raise ValueError('sigma_min and sigma_max must not be 0') with tqdm(total=n, disable=disable) as pbar: dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update) if callback is not None: @@ -410,6 +410,8 @@ def sample_dpm_fast(model, x, sigma_min, sigma_max, n, extra_args=None, callback @torch.no_grad() def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callback=None, disable=None, order=3, rtol=0.05, atol=0.0078, h_init=0.05, pcoeff=0., icoeff=1., dcoeff=0., accept_safety=0.81, return_info=False): """DPM-Solver-12 and 23 (adaptive step size). See https://arxiv.org/abs/2206.00927.""" + if sigma_min <= 0 or sigma_max <= 0: + raise ValueError('sigma_min and sigma_max must not be 0') with tqdm(disable=disable) as pbar: dpm_solver = DPMSolver(model, extra_args, eps_callback=pbar.update) if callback is not None: