diff --git a/diffopt/kdescent/descent.py b/diffopt/kdescent/descent.py
index 8069e23..6517feb 100644
--- a/diffopt/kdescent/descent.py
+++ b/diffopt/kdescent/descent.py
@@ -12,7 +12,8 @@
 
 
 def adam(lossfunc, guess, nsteps=100, param_bounds=None,
-         learning_rate=0.01, randkey=1, const_randkey=False, **other_kwargs):
+         learning_rate=0.01, randkey=1, const_randkey=False,
+         thin=1, progress=True, **other_kwargs):
     """
     Perform gradient descent
 
@@ -36,6 +37,11 @@ def adam(lossfunc, guess, nsteps=100, param_bounds=None,
     const_randkey : bool, optional
         By default (False), randkey is regenerated at each gradient descent
         iteration. Remove this behavior by setting const_randkey=True
+    thin : int, optional
+        Return parameters for every `thin` iterations, by default 1. Set
+        `thin=0` to only return final parameters
+    progress : bool, optional
+        Display tqdm progress bar, by default True
 
     Returns
     -------
@@ -46,7 +52,7 @@
     if param_bounds is None:
         return adam_unbounded(
             lossfunc, guess, nsteps, learning_rate, randkey,
-            const_randkey, **other_kwargs)
+            const_randkey, thin, progress, **other_kwargs)
 
     assert len(guess) == len(param_bounds)
     if hasattr(param_bounds, "tolist"):
@@ -60,14 +66,15 @@ def ulossfunc(uparams, *args, **kwargs):
     init_uparams = apply_transforms(guess, param_bounds)
     uparams = adam_unbounded(
         ulossfunc, init_uparams, nsteps, learning_rate, randkey,
-        const_randkey, **other_kwargs)
+        const_randkey, thin, progress, **other_kwargs)
     params = apply_inverse_transforms(uparams.T, param_bounds).T
 
     return params
 
 
 def adam_unbounded(lossfunc, guess, nsteps=100, learning_rate=0.01,
-                   randkey=1, const_randkey=False, **other_kwargs):
+                   randkey=1, const_randkey=False,
+                   thin=1, progress=True, **other_kwargs):
     kwargs = {**other_kwargs}
     if randkey is not None:
         randkey = keygen.init_randkey(randkey)
@@ -78,13 +85,18 @@ def adam_unbounded(lossfunc, guess, nsteps=100, learning_rate=0.01,
     opt = optax.adam(learning_rate)
     solver = jaxopt.OptaxSolver(opt=opt, fun=lossfunc, maxiter=nsteps)
     state = solver.init_state(guess, **kwargs)
-    params = [guess]
-    for _ in tqdm.trange(nsteps, desc="Adam Gradient Descent Progress"):
+    params = []
+    params_i = guess
+    for i in tqdm.trange(nsteps, disable=not progress,
+                         desc="Adam Gradient Descent Progress"):
         if randkey is not None:
             randkey, key_i = jax.random.split(randkey)
             kwargs["randkey"] = key_i
-        params_i, state = solver.update(params[-1], state, **kwargs)
-        params.append(params_i)
+        params_i, state = solver.update(params_i, state, **kwargs)
+        if i == nsteps - 1 or (thin and i % thin == thin - 1):
+            params.append(params_i)
+    if not thin:
+        params = params[-1]
 
     return jnp.array(params)
 
diff --git a/diffopt/multigrad/adam.py b/diffopt/multigrad/adam.py
index 078c47e..d87bc01 100644
--- a/diffopt/multigrad/adam.py
+++ b/diffopt/multigrad/adam.py
@@ -25,12 +25,12 @@
 N_RANKS = 1
 
 
-def trange_no_tqdm(n, desc=None):
+def trange_no_tqdm(n, desc=None, disable=False):
     return range(n)
 
 
-def trange_with_tqdm(n, desc="Adam Gradient Descent Progress"):
-    return tqdm.trange(n, desc=desc)
+def trange_with_tqdm(n, desc="Adam Gradient Descent Progress", disable=False):
+    return tqdm.trange(n, desc=desc, disable=disable)
 
 
 adam_trange = trange_no_tqdm if tqdm is None else trange_with_tqdm
@@ -49,27 +49,32 @@ def _master_wrapper(params, logloss_and_grad_fn, data, randkey=None):
     return loss, grad
 
 
-def _adam_optimizer(params, fn, fn_data, nsteps, learning_rate, randkey=None):
+def _adam_optimizer(params, fn, fn_data, nsteps, learning_rate, randkey=None,
+                    thin=1, progress=True):
     kwargs = {}
     # Note: Might be recommended to use optax instead of jax.example_libraries
     opt_init, opt_update, get_params = jax_opt.adam(learning_rate)
     opt_state = opt_init(params)
 
-    param_steps = [params]
-    for step in adam_trange(nsteps):
+    param_steps = []
+    for step in adam_trange(nsteps, disable=not progress):
         if randkey is not None:
             randkey, key_i = jax.random.split(randkey)
             kwargs["randkey"] = key_i
         _, grad = fn(params, *fn_data, **kwargs)
         opt_state = opt_update(step, grad, opt_state)
         params = get_params(opt_state)
-        param_steps.append(params)
+        if step == nsteps - 1 or (thin and step % thin == thin - 1):
+            param_steps.append(params)
+    if not thin:
+        param_steps = param_steps[-1]
 
     return jnp.array(param_steps)
 
 
 def run_adam_unbounded(logloss_and_grad_fn, params, data, nsteps=100,
-                       learning_rate=0.01, randkey=None):
+                       learning_rate=0.01, randkey=None,
+                       thin=1, progress=True):
     """Run the adam optimizer on a loss function with a custom gradient.
 
     Parameters
@@ -88,6 +93,11 @@ def run_adam_unbounded(logloss_and_grad_fn, params, data, nsteps=100,
     randkey : int | PRNG Key
         If given, a new PRNG Key will be generated at each iteration and be
         passed to `logloss_and_grad_fn` under the "randkey" kwarg
+    thin : int, optional
+        Return parameters for every `thin` iterations, by default 1. Set
+        `thin=0` to only return final parameters
+    progress : bool, optional
+        Display tqdm progress bar, by default True
 
     Returns
     -------
@@ -104,7 +114,7 @@
         fn_data = (logloss_and_grad_fn, data)
 
         params = _adam_optimizer(params, fn, fn_data, nsteps, learning_rate,
-                                 randkey=randkey)
+                                 randkey=randkey, thin=thin, progress=progress)
 
         if COMM is not None:
             COMM.bcast("exit", root=0)
@@ -131,7 +141,7 @@
 
 
 def run_adam(logloss_and_grad_fn, params, data, nsteps=100, param_bounds=None,
-             learning_rate=0.01, randkey=None):
+             learning_rate=0.01, randkey=None, thin=1, progress=True):
     """Run the adam optimizer on a loss function with a custom gradient.
 
     Parameters
@@ -153,6 +163,11 @@ def run_adam(logloss_and_grad_fn, params, data, nsteps=100, param_bounds=None,
     randkey : int | PRNG Key
         If given, a new PRNG Key will be generated at each iteration and be
         passed to `logloss_and_grad_fn` under the "randkey" kwarg
+    thin : int, optional
+        Return parameters for every `thin` iterations, by default 1. Set
+        `thin=0` to only return final parameters
+    progress : bool, optional
+        Display tqdm progress bar, by default True
 
     Returns
     -------
@@ -162,7 +177,8 @@
     if param_bounds is None:
         return run_adam_unbounded(
             logloss_and_grad_fn, params, data, nsteps=nsteps,
-            learning_rate=learning_rate, randkey=randkey)
+            learning_rate=learning_rate, randkey=randkey,
+            thin=thin, progress=progress)
 
     assert len(params) == len(param_bounds)
     if hasattr(param_bounds, "tolist"):
@@ -182,7 +198,8 @@ def unbound_loss_and_grad(uparams, *args, **kwargs):
     uparams = apply_trans(params)
 
     final_uparams = run_adam_unbounded(
-        unbound_loss_and_grad, uparams, data, nsteps, learning_rate, randkey)
+        unbound_loss_and_grad, uparams, data, nsteps, learning_rate, randkey,
+        thin, progress)
 
     if RANK == 0:
         final_params = invert_trans(final_uparams.T).T
diff --git a/diffopt/multigrad/bfgs.py b/diffopt/multigrad/bfgs.py
index 302875e..93d65eb 100644
--- a/diffopt/multigrad/bfgs.py
+++ b/diffopt/multigrad/bfgs.py
@@ -18,19 +18,19 @@
 N_RANKS = 1
 
 
-def trange_no_tqdm(n, desc=None):
+def trange_no_tqdm(n, desc=None, disable=False):
     return range(n)
 
 
-def trange_with_tqdm(n, desc="BFGS Gradient Descent Progress"):
-    return tqdm.trange(n, desc=desc, leave=True)
+def trange_with_tqdm(n, desc="BFGS Gradient Descent Progress", disable=False):
+    return tqdm.trange(n, desc=desc, leave=True, disable=disable)
 
 
 bfgs_trange = trange_no_tqdm if tqdm is None else trange_with_tqdm
 
 
 def run_bfgs(loss_and_grad_fn, params, maxsteps=100, param_bounds=None,
-             randkey=None, comm=COMM):
+             randkey=None, progress=True, comm=COMM):
     """Run the adam optimizer on a loss function with a custom gradient.
 
     Parameters
@@ -46,6 +46,8 @@ def run_bfgs(loss_and_grad_fn, params, maxsteps=100, param_bounds=None,
         `None` as the bound for each unbounded parameter, by default None
     randkey : int | PRNG Key (default=None)
         This will be passed to `logloss_and_grad_fn` under the "randkey" kwarg
+    progress : bool, optional
+        Display tqdm progress bar, by default True
     comm : MPI Communicator (default=COMM_WORLD)
         Communicator between all desired MPI ranks
 
@@ -66,7 +68,7 @@
         kwargs["randkey"] = randkey
 
     if comm is None or comm.rank == 0:
-        pbar = bfgs_trange(maxsteps)
+        pbar = bfgs_trange(maxsteps, disable=not progress)
 
         # Wrap loss_and_grad function with commands to the worker ranks
         def loss_and_grad_fn_root(params):
diff --git a/diffopt/multigrad/multigrad.py b/diffopt/multigrad/multigrad.py
index b8b374b..2397c87 100644
--- a/diffopt/multigrad/multigrad.py
+++ b/diffopt/multigrad/multigrad.py
@@ -224,7 +224,8 @@ def calc_loss_from_sumstats(self, sumstats, sumstats_aux=None,
 
     # NOTE: Never jit this method because it uses mpi4py
     def run_simple_grad_descent(self: Any, guess,
-                                nsteps=100, learning_rate=0.01):
+                                nsteps=100, learning_rate=0.01,
+                                thin=1, progress=True):
         """
         Descend the gradient with a fixed learning rate to optimize
         parameters, given an initial guess. Stochasticity not allowed.
@@ -237,6 +238,11 @@
             The number of steps to take.
         learning_rate : float (default=0.001)
             The fixed learning rate.
+        thin : int, optional
+            Return parameters for every `thin` iterations, by default 1. Set
+            `thin=0` to only return final parameters
+        progress : bool, optional
+            Display tqdm progress bar, by default True
 
         Returns
         -------
@@ -253,12 +259,14 @@
             learning_rate=learning_rate,
             loss_and_grad_func=self.calc_loss_and_grad_from_params,
             has_aux=False,
+            thin=thin,
+            progress=progress
         )
 
     # NOTE: Never jit this method because it uses mpi4py
     def run_adam(self: Any, guess, nsteps=100, param_bounds=None,
                  learning_rate=0.01, randkey=None, const_randkey=False,
-                 comm=None):
+                 thin=1, progress=True, comm=None):
         """
         Run adam to descend the gradient and optimize the model parameters,
         given an initial guess. Stochasticity is allowed if randkey is passed.
@@ -280,6 +288,11 @@ def run_adam(self: Any, guess, nsteps=100, param_bounds=None,
         const_randkey : bool (default=False)
             By default, randkey is regenerated at each gradient descent
             iteration. Remove this behavior by setting const_randkey=True
+        thin : int, optional
+            Return parameters for every `thin` iterations, by default 1. Set
+            `thin=0` to only return final parameters
+        progress : bool, optional
+            Display tqdm progress bar, by default True
 
         Returns
         -------
@@ -301,14 +314,14 @@ def loss_and_grad_fn(x, _, **kw):
 
         params_steps = run_adam(
             loss_and_grad_fn, params=guess, data=None, nsteps=nsteps,
             param_bounds=param_bounds, learning_rate=learning_rate,
-            randkey=randkey
+            randkey=randkey, thin=thin, progress=progress
         )
         return jnp.asarray(comm.bcast(params_steps, root=0))
 
     # NOTE: Never jit this method because it uses mpi4py
     def run_bfgs(self: Any, guess, maxsteps=100, param_bounds=None,
-                 randkey=None, comm=None):
+                 randkey=None, progress=True, comm=None):
         """
         Run BFGS to descend the gradient and optimize the model parameters,
         given an initial guess. Stochasticity must be held fixed via a random
@@ -327,6 +340,8 @@ def run_bfgs(self: Any, guess, maxsteps=100, param_bounds=None,
             Since BFGS requires a deterministic function, this key will be
             passed to `calc_loss_and_grad_from_params()` as the "randkey"
             kwarg as a constant at every iteration
+        progress : bool, optional
+            Display tqdm progress bar, by default True
 
         Returns
         -------
@@ -349,7 +364,8 @@
         comm = self.comm if comm is None else comm
         return run_bfgs(
             self.calc_loss_and_grad_from_params, guess, maxsteps=maxsteps,
-            param_bounds=param_bounds, randkey=randkey, comm=comm)
+            param_bounds=param_bounds, randkey=randkey,
+            progress=progress, comm=comm)
 
     def run_lhs_param_scan(self, xmins, xmaxs, n_dim, num_evaluations,
                            seed=None, randkey=None):
@@ -581,22 +597,27 @@ def calc_loss_and_grad_from_params(self, params):
 
     # NOTE: Never jit this method because it uses mpi4py
     def run_simple_grad_descent(self, guess,
-                                nsteps=100, learning_rate=0.01):
+                                nsteps=100, learning_rate=0.01,
+                                thin=1, progress=True):
         return OnePointModel.run_simple_grad_descent(
-            self, guess, nsteps, learning_rate)
+            self, guess, nsteps, learning_rate, thin=thin, progress=progress)
 
     # NOTE: Never jit this method because it uses mpi4py
-    def run_bfgs(self, guess, maxsteps=100, param_bounds=None, randkey=None):
+    def run_bfgs(self, guess, maxsteps=100, param_bounds=None, randkey=None,
+                 progress=True):
         return OnePointModel.run_bfgs(
             self, guess, maxsteps, param_bounds=param_bounds,
-            randkey=randkey, comm=self.main_comm)
+            randkey=randkey, progress=progress,
+            comm=self.main_comm)
 
     # NOTE: Never jit this method because it uses mpi4py
     def run_adam(self, guess, nsteps=100, param_bounds=None,
-                 learning_rate=0.01, randkey=None, const_randkey=False):
+                 learning_rate=0.01, randkey=None, const_randkey=False,
+                 thin=1, progress=True):
         return OnePointModel.run_adam(
             self, guess, nsteps, param_bounds, learning_rate, randkey,
-            const_randkey=const_randkey, comm=self.main_comm)
+            const_randkey=const_randkey, thin=thin, progress=progress,
+            comm=self.main_comm)
 
     def __hash__(self):
         if isinstance(self.models, OnePointModel):
diff --git a/diffopt/multigrad/tests/test_mpi.py b/diffopt/multigrad/tests/test_mpi.py
index efc0102..6606357 100644
--- a/diffopt/multigrad/tests/test_mpi.py
+++ b/diffopt/multigrad/tests/test_mpi.py
@@ -58,7 +58,7 @@ def test_simple_grad_descent_pipeline():
     gd_loss, gd_params = gd_iterations.loss, gd_iterations.params
     assert jnp.isclose(gd_loss[-1], 0.0)
     assert jnp.allclose(gd_params[-1], jnp.array([*truth]))
-    assert jnp.allclose(true_gradloss, 0.0, atol=1e-5)
+    assert jnp.allclose(true_gradloss, 0.0, atol=1e-4)
 
     # Calculate grad(loss) with the more memory efficient method
     loss, dloss_dparams = model.calc_loss_and_grad_from_params(truth)
diff --git a/diffopt/multigrad/util.py b/diffopt/multigrad/util.py
index c39a161..9d4cf4c 100644
--- a/diffopt/multigrad/util.py
+++ b/diffopt/multigrad/util.py
@@ -36,12 +36,12 @@
            "latin_hypercube_sampler", "scatter_nd"]
 
 
-def trange_no_tqdm(n, desc=None):
+def trange_no_tqdm(n, desc=None, disable=False):
     return range(n)
 
 
-def trange_with_tqdm(n, desc=None):
-    return tqdm.trange(n, desc=desc)
+def trange_with_tqdm(n, desc=None, disable=False):
+    return tqdm.trange(n, desc=desc, disable=disable)
 
 
 trange = trange_no_tqdm if tqdm is None else trange_with_tqdm
@@ -85,6 +85,8 @@ def simple_grad_descent(
         loss_and_grad_func=None,
         grad_loss_func=None,
         has_aux=False,
+        thin=1,
+        progress=True,
         **kwargs,
 ):
     if loss_and_grad_func is None:
@@ -115,13 +117,19 @@ def loopfunc(state, _x):
 
     # The below is equivalent to lax.scan without jitting
     # ===================================================
-    initstate = (0.0, guess)
+    state = (0.0, guess)
     loss, params, aux = [], [], []
-    for x in trange(nsteps, desc="Simple Gradient Descent Progress"):
-        initstate, y = loopfunc(initstate, x)
-        loss.append(y[0])
-        params.append(y[1])
-        aux.append(y[2])
+    for x in trange(nsteps, desc="Simple Gradient Descent Progress",
+                    disable=not progress):
+        state, y = loopfunc(state, x)
+        if x == nsteps - 1 or (thin and x % thin == thin - 1):
+            loss.append(y[0])
+            params.append(y[1])
+            aux.append(y[2])
+    if not thin:
+        loss = loss[-1]
+        params = params[-1]
+        aux = aux[-1]
     loss = jnp.array(loss)
     params = jnp.array(params)
     if has_aux:
diff --git a/diffopt/multiswarm/pso_update.py b/diffopt/multiswarm/pso_update.py
index f08f491..0db1588 100644
--- a/diffopt/multiswarm/pso_update.py
+++ b/diffopt/multiswarm/pso_update.py
@@ -95,7 +95,8 @@ def __init__(self, nparticles, ndim, xlow, xhigh, seed=0,
         self.social_weight = social_weight
         self.vmax_frac = vmax_frac
 
-    def run_pso(self, lossfunc, nsteps=100, keep_init_random_state=False):
+    def run_pso(self, lossfunc, nsteps=100, progress=True,
+                keep_init_random_state=False):
         """
         Run particle swarm optimization (PSO)
 
@@ -106,6 +107,8 @@ def run_pso(self, lossfunc, nsteps=100, keep_init_random_state=False):
             with signature `lossfunc(x)` where x is an array of shape `(ndim,)`
         nsteps : int, optional
             Number of time step iterations, by default 100
+        progress : bool, optional
+            Display tqdm progress bar, by default True
         keep_init_random_state : bool, optional
             Set True to be able to rerun an identical run, or False (default)
             to continue a run by manually setting swarm.x_init and swarm.v_init
@@ -140,12 +143,12 @@ def run_pso(self, lossfunc, nsteps=100, keep_init_random_state=False):
         loc_loss_history = [[] for _ in range(self.num_particles_on_this_rank)]
         start = time()
 
-        def trange(x):
+        def trange(x, disable=False):
             if self.comm.rank:
                 return range(x)
             else:
-                return tqdm.trange(x, desc="PSO Progress")
-        for _ in trange(nsteps):
+                return tqdm.trange(x, desc="PSO Progress", disable=disable)
+        for _ in trange(nsteps, disable=not progress):
             istep_loss = [None for _ in range(self.num_particles_on_this_rank)]
             for ip in range(self.num_particles_on_this_rank):
                 update_key = jran.split(particle_keys[ip], 1)[0]
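
Usage note (not part of the patch above): the sketch below shows how the new `thin` and `progress` keywords behave, using the `kdescent` entry point from the first file in this diff. The quadratic `lossfunc` and the `guess` value are hypothetical stand-ins; any JAX-differentiable loss that accepts the `randkey` kwarg works the same way.

```python
import jax.numpy as jnp

from diffopt.kdescent.descent import adam

# Hypothetical loss for illustration only. `adam` passes a fresh
# `randkey` kwarg at every iteration, so the signature must accept it.
def lossfunc(params, randkey=None):
    return jnp.sum(params ** 2)

guess = jnp.ones(3)

# thin=10 keeps every 10th iteration (the final one is always kept),
# so the returned history has shape (10, 3) for nsteps=100
params = adam(lossfunc, guess, nsteps=100, thin=10, progress=False)

# thin=0 returns only the final parameters, shape (3,)
final = adam(lossfunc, guess, nsteps=100, thin=0, progress=False)
```

Note that even with the default `thin=1`, the initial guess is no longer prepended to the returned history (previously `params = [guess]`), so callers now receive `nsteps` rows instead of `nsteps + 1`.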