Add thin and progress kwargs to fitters
AlanPearl committed Nov 7, 2024
1 parent 94d306a commit 62d8db1
Showing 7 changed files with 113 additions and 50 deletions.
28 changes: 20 additions & 8 deletions diffopt/kdescent/descent.py
@@ -12,7 +12,8 @@


def adam(lossfunc, guess, nsteps=100, param_bounds=None,
learning_rate=0.01, randkey=1, const_randkey=False, **other_kwargs):
learning_rate=0.01, randkey=1, const_randkey=False,
thin=1, progress=True, **other_kwargs):
"""
Perform gradient descent
@@ -36,6 +37,11 @@ def adam(lossfunc, guess, nsteps=100, param_bounds=None,
const_randkey : bool, optional
By default (False), randkey is regenerated at each gradient descent
iteration. Remove this behavior by setting const_randkey=True
thin : int, optional
Return parameters for every `thin` iterations, by default 1. Set
`thin=0` to only return final parameters
progress : bool, optional
Display tqdm progress bar, by default True
Returns
-------
@@ -46,7 +52,7 @@
if param_bounds is None:
return adam_unbounded(
lossfunc, guess, nsteps, learning_rate, randkey,
const_randkey, **other_kwargs)
const_randkey, thin, progress, **other_kwargs)

assert len(guess) == len(param_bounds)
if hasattr(param_bounds, "tolist"):
@@ -60,14 +66,15 @@ def ulossfunc(uparams, *args, **kwargs):
init_uparams = apply_transforms(guess, param_bounds)
uparams = adam_unbounded(
ulossfunc, init_uparams, nsteps, learning_rate, randkey,
const_randkey, **other_kwargs)
const_randkey, thin, progress, **other_kwargs)
params = apply_inverse_transforms(uparams.T, param_bounds).T

return params


def adam_unbounded(lossfunc, guess, nsteps=100, learning_rate=0.01,
randkey=1, const_randkey=False, **other_kwargs):
randkey=1, const_randkey=False,
thin=1, progress=True, **other_kwargs):
kwargs = {**other_kwargs}
if randkey is not None:
randkey = keygen.init_randkey(randkey)
@@ -78,13 +85,18 @@ def adam_unbounded(lossfunc, guess, nsteps=100, learning_rate=0.01,
opt = optax.adam(learning_rate)
solver = jaxopt.OptaxSolver(opt=opt, fun=lossfunc, maxiter=nsteps)
state = solver.init_state(guess, **kwargs)
params = [guess]
for _ in tqdm.trange(nsteps, desc="Adam Gradient Descent Progress"):
params = []
params_i = guess
for i in tqdm.trange(nsteps, disable=not progress,
desc="Adam Gradient Descent Progress"):
if randkey is not None:
randkey, key_i = jax.random.split(randkey)
kwargs["randkey"] = key_i
params_i, state = solver.update(params[-1], state, **kwargs)
params.append(params_i)
params_i, state = solver.update(params_i, state, **kwargs)
if i == nsteps - 1 or (thin and i % thin == thin - 1):
params.append(params_i)
if not thin:
params = params[-1]

return jnp.array(params)

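For orientation, here is a minimal usage sketch of the updated kdescent fitter call. The import path is inferred from the file path above and `lossfunc` is a placeholder; neither is part of this commit.

import jax.numpy as jnp
from diffopt.kdescent.descent import adam  # path inferred from this diff

def lossfunc(params, randkey=None):
    # Placeholder quadratic loss; it accepts `randkey` because adam passes a
    # fresh key at every iteration unless randkey=None is given.
    return jnp.sum((params - 3.0) ** 2)

guess = jnp.zeros(2)

# thin=10 records iterations 9, 19, ..., 99 (the final step is always kept),
# so `history` has shape (10, 2) here.
history = adam(lossfunc, guess, nsteps=100, thin=10, progress=True)

# thin=0 returns only the final parameter vector, with the progress bar hidden.
final = adam(lossfunc, guess, nsteps=100, thin=0, progress=False)
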
41 changes: 29 additions & 12 deletions diffopt/multigrad/adam.py
@@ -25,12 +25,12 @@
N_RANKS = 1


def trange_no_tqdm(n, desc=None):
def trange_no_tqdm(n, desc=None, disable=False):
return range(n)


def trange_with_tqdm(n, desc="Adam Gradient Descent Progress"):
return tqdm.trange(n, desc=desc)
def trange_with_tqdm(n, desc="Adam Gradient Descent Progress", disable=False):
return tqdm.trange(n, desc=desc, disable=disable)


adam_trange = trange_no_tqdm if tqdm is None else trange_with_tqdm
@@ -49,27 +49,32 @@ def _master_wrapper(params, logloss_and_grad_fn, data, randkey=None):
return loss, grad


def _adam_optimizer(params, fn, fn_data, nsteps, learning_rate, randkey=None):
def _adam_optimizer(params, fn, fn_data, nsteps, learning_rate, randkey=None,
thin=1, progress=True):
kwargs = {}
# Note: Might be recommended to use optax instead of jax.example_libraries
opt_init, opt_update, get_params = jax_opt.adam(learning_rate)
opt_state = opt_init(params)

param_steps = [params]
for step in adam_trange(nsteps):
param_steps = []
for step in adam_trange(nsteps, disable=not progress):
if randkey is not None:
randkey, key_i = jax.random.split(randkey)
kwargs["randkey"] = key_i
_, grad = fn(params, *fn_data, **kwargs)
opt_state = opt_update(step, grad, opt_state)
params = get_params(opt_state)
param_steps.append(params)
if step == nsteps - 1 or (thin and step % thin == thin - 1):
param_steps.append(params)
if not thin:
param_steps = param_steps[-1]

return jnp.array(param_steps)
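
As an aside, the recording rule used in the loop above can be summarized by this standalone predicate (an illustrative sketch, not code from the repository):

def recorded_steps(nsteps, thin):
    # 0-indexed iterations whose parameters are appended to param_steps
    return [i for i in range(nsteps)
            if i == nsteps - 1 or (thin and i % thin == thin - 1)]

print(recorded_steps(10, 3))  # [2, 5, 8, 9] -- every 3rd step plus the final step
print(recorded_steps(10, 0))  # [9] -- thin=0 keeps only the final step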


def run_adam_unbounded(logloss_and_grad_fn, params, data, nsteps=100,
learning_rate=0.01, randkey=None):
learning_rate=0.01, randkey=None,
thin=1, progress=True):
"""Run the adam optimizer on a loss function with a custom gradient.
Parameters
@@ -88,6 +93,11 @@ def run_adam_unbounded(logloss_and_grad_fn, params, data, nsteps=100,
randkey : int | PRNG Key
If given, a new PRNG Key will be generated at each iteration and be
passed to `logloss_and_grad_fn` under the "randkey" kwarg
thin : int, optional
Return parameters for every `thin` iterations, by default 1. Set
`thin=0` to only return final parameters
progress : bool, optional
Display tqdm progress bar, by default True
Returns
-------
@@ -104,7 +114,7 @@ def run_adam_unbounded(logloss_and_grad_fn, params, data, nsteps=100,
fn_data = (logloss_and_grad_fn, data)

params = _adam_optimizer(params, fn, fn_data, nsteps, learning_rate,
randkey=randkey)
randkey=randkey, thin=thin, progress=progress)

if COMM is not None:
COMM.bcast("exit", root=0)
@@ -131,7 +141,7 @@ def run_adam_unbounded(logloss_and_grad_fn, params, data, nsteps=100,


def run_adam(logloss_and_grad_fn, params, data, nsteps=100, param_bounds=None,
learning_rate=0.01, randkey=None):
learning_rate=0.01, randkey=None, thin=1, progress=True):
"""Run the adam optimizer on a loss function with a custom gradient.
Parameters
@@ -153,6 +163,11 @@ def run_adam(logloss_and_grad_fn, params, data, nsteps=100, param_bounds=None,
randkey : int | PRNG Key
If given, a new PRNG Key will be generated at each iteration and be
passed to `logloss_and_grad_fn` under the "randkey" kwarg
thin : int, optional
Return parameters for every `thin` iterations, by default 1. Set
`thin=0` to only return final parameters
progress : bool, optional
Display tqdm progress bar, by default True
Returns
-------
@@ -162,7 +177,8 @@
if param_bounds is None:
return run_adam_unbounded(
logloss_and_grad_fn, params, data, nsteps=nsteps,
learning_rate=learning_rate, randkey=randkey)
learning_rate=learning_rate, randkey=randkey,
thin=thin, progress=progress)

assert len(params) == len(param_bounds)
if hasattr(param_bounds, "tolist"):
@@ -182,7 +198,8 @@ def unbound_loss_and_grad(uparams, *args, **kwargs):

uparams = apply_trans(params)
final_uparams = run_adam_unbounded(
unbound_loss_and_grad, uparams, data, nsteps, learning_rate, randkey)
unbound_loss_and_grad, uparams, data, nsteps, learning_rate, randkey,
thin, progress)

if RANK == 0:
final_params = invert_trans(final_uparams.T).T
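
A sketch of how the new keywords flow through run_adam in a single-process (no MPI) run; the loss-and-gradient function and data below are illustrative placeholders, not part of the repository:

import jax.numpy as jnp
from diffopt.multigrad.adam import run_adam  # path inferred from this diff

def logloss_and_grad(params, data, randkey=None):
    # Placeholder quadratic loss returned together with its analytic gradient
    resid = params - data
    return jnp.sum(resid ** 2), 2.0 * resid

data = jnp.array([1.0, -2.0])
guess = jnp.zeros(2)

# Keep every 5th step (plus the final one) and silence the tqdm bar
param_steps = run_adam(logloss_and_grad, guess, data, nsteps=50,
                       learning_rate=0.05, thin=5, progress=False)
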
12 changes: 7 additions & 5 deletions diffopt/multigrad/bfgs.py
@@ -18,19 +18,19 @@
N_RANKS = 1


def trange_no_tqdm(n, desc=None):
def trange_no_tqdm(n, desc=None, disable=False):
return range(n)


def trange_with_tqdm(n, desc="BFGS Gradient Descent Progress"):
return tqdm.trange(n, desc=desc, leave=True)
def trange_with_tqdm(n, desc="BFGS Gradient Descent Progress", disable=False):
return tqdm.trange(n, desc=desc, leave=True, disable=disable)


bfgs_trange = trange_no_tqdm if tqdm is None else trange_with_tqdm


def run_bfgs(loss_and_grad_fn, params, maxsteps=100, param_bounds=None,
randkey=None, comm=COMM):
randkey=None, progress=True, comm=COMM):
"""Run the adam optimizer on a loss function with a custom gradient.
Parameters
@@ -46,6 +46,8 @@ def run_bfgs(loss_and_grad_fn, params, maxsteps=100, param_bounds=None,
`None` as the bound for each unbounded parameter, by default None
randkey : int | PRNG Key (default=None)
This will be passed to `logloss_and_grad_fn` under the "randkey" kwarg
progress : bool, optional
Display tqdm progress bar, by default True
comm : MPI Communicator (default=COMM_WORLD)
Communicator between all desired MPI ranks
@@ -66,7 +68,7 @@ def run_bfgs(loss_and_grad_fn, params, maxsteps=100, param_bounds=None,
kwargs["randkey"] = randkey

if comm is None or comm.rank == 0:
pbar = bfgs_trange(maxsteps)
pbar = bfgs_trange(maxsteps, disable=not progress)

# Wrap loss_and_grad function with commands to the worker ranks
def loss_and_grad_fn_root(params):
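
Similarly, a sketch of suppressing the progress bar in run_bfgs, assuming a deterministic placeholder loss and a single-rank run with comm=None (the return value is left unspecified here):

import jax.numpy as jnp
from diffopt.multigrad.bfgs import run_bfgs  # path inferred from this diff

def loss_and_grad(params, randkey=None):
    # Deterministic placeholder loss with its analytic gradient, as BFGS requires
    return jnp.sum(params ** 2), 2.0 * params

guess = jnp.array([0.5, -1.5])
result = run_bfgs(loss_and_grad, guess, maxsteps=50, progress=False, comm=None)
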
43 changes: 32 additions & 11 deletions diffopt/multigrad/multigrad.py
@@ -224,7 +224,8 @@ def calc_loss_from_sumstats(self, sumstats, sumstats_aux=None,

# NOTE: Never jit this method because it uses mpi4py
def run_simple_grad_descent(self: Any, guess,
nsteps=100, learning_rate=0.01):
nsteps=100, learning_rate=0.01,
thin=1, progress=True):
"""
Descend the gradient with a fixed learning rate to optimize parameters,
given an initial guess. Stochasticity not allowed.
@@ -237,6 +238,11 @@
The number of steps to take.
learning_rate : float (default=0.01)
The fixed learning rate.
thin : int, optional
Return parameters for every `thin` iterations, by default 1. Set
`thin=0` to only return final parameters
progress : bool, optional
Display tqdm progress bar, by default True
Returns
-------
@@ -253,12 +259,14 @@
learning_rate=learning_rate,
loss_and_grad_func=self.calc_loss_and_grad_from_params,
has_aux=False,
thin=thin,
progress=progress
)

# NOTE: Never jit this method because it uses mpi4py
def run_adam(self: Any, guess, nsteps=100, param_bounds=None,
learning_rate=0.01, randkey=None, const_randkey=False,
comm=None):
thin=1, progress=True, comm=None):
"""
Run adam to descend the gradient and optimize the model parameters,
given an initial guess. Stochasticity is allowed if randkey is passed.
@@ -280,6 +288,11 @@
const_randkey : bool (default=False)
By default, randkey is regenerated at each gradient descent
iteration. Remove this behavior by setting const_randkey=True
thin : int, optional
Return parameters for every `thin` iterations, by default 1. Set
`thin=0` to only return final parameters
progress : bool, optional
Display tqdm progress bar, by default True
Returns
-------
@@ -301,14 +314,14 @@ def loss_and_grad_fn(x, _, **kw):
params_steps = run_adam(
loss_and_grad_fn, params=guess, data=None, nsteps=nsteps,
param_bounds=param_bounds, learning_rate=learning_rate,
randkey=randkey
randkey=randkey, thin=thin, progress=progress
)

return jnp.asarray(comm.bcast(params_steps, root=0))

# NOTE: Never jit this method because it uses mpi4py
def run_bfgs(self: Any, guess, maxsteps=100, param_bounds=None,
randkey=None, comm=None):
randkey=None, progress=True, comm=None):
"""
Run BFGS to descend the gradient and optimize the model parameters,
given an initial guess. Stochasticity must be held fixed via a random
@@ -327,6 +340,8 @@ def run_bfgs(self: Any, guess, maxsteps=100, param_bounds=None,
Since BFGS requires a deterministic function, this key will be
passed to `calc_loss_and_grad_from_params()` as the "randkey" kwarg
as a constant at every iteration
progress : bool, optional
Display tqdm progress bar, by default True
Returns
-------
@@ -349,7 +364,8 @@
comm = self.comm if comm is None else comm
return run_bfgs(
self.calc_loss_and_grad_from_params, guess, maxsteps=maxsteps,
param_bounds=param_bounds, randkey=randkey, comm=comm)
param_bounds=param_bounds, randkey=randkey,
progress=progress, comm=comm)

def run_lhs_param_scan(self, xmins, xmaxs, n_dim,
num_evaluations, seed=None, randkey=None):
@@ -581,22 +597,27 @@ def calc_loss_and_grad_from_params(self, params):

# NOTE: Never jit this method because it uses mpi4py
def run_simple_grad_descent(self, guess,
nsteps=100, learning_rate=0.01):
nsteps=100, learning_rate=0.01,
thin=1, progress=True):
return OnePointModel.run_simple_grad_descent(
self, guess, nsteps, learning_rate)
self, guess, nsteps, learning_rate, thin=thin, progress=progress)

# NOTE: Never jit this method because it uses mpi4py
def run_bfgs(self, guess, maxsteps=100, param_bounds=None, randkey=None):
def run_bfgs(self, guess, maxsteps=100, param_bounds=None, randkey=None,
progress=True):
return OnePointModel.run_bfgs(
self, guess, maxsteps, param_bounds=param_bounds,
randkey=randkey, comm=self.main_comm)
randkey=randkey, progress=progress,
comm=self.main_comm)

# NOTE: Never jit this method because it uses mpi4py
def run_adam(self, guess, nsteps=100, param_bounds=None,
learning_rate=0.01, randkey=None, const_randkey=False):
learning_rate=0.01, randkey=None, const_randkey=False,
thin=1, progress=True):
return OnePointModel.run_adam(
self, guess, nsteps, param_bounds, learning_rate, randkey,
const_randkey=const_randkey, comm=self.main_comm)
const_randkey=const_randkey, thin=thin, progress=progress,
comm=self.main_comm)

def __hash__(self):
if isinstance(self.models, OnePointModel):
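
Finally, assuming `model` is an instance of some concrete OnePointModel subclass (not defined here), the new keywords pass straight through the high-level methods, for example:

import jax.numpy as jnp

guess = jnp.ones(3)

# Keep every 20th Adam step and hide the tqdm bar
steps = model.run_adam(guess, nsteps=200, thin=20, progress=False)

# thin=0 returns only the final parameter vector
final = model.run_adam(guess, nsteps=200, thin=0)

# run_bfgs gains only the progress toggle; run_simple_grad_descent gains both
best = model.run_bfgs(guess, maxsteps=50, progress=False)
gd_result = model.run_simple_grad_descent(guess, nsteps=100, thin=10,
                                          progress=False)
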
2 changes: 1 addition & 1 deletion diffopt/multigrad/tests/test_mpi.py
@@ -58,7 +58,7 @@ def test_simple_grad_descent_pipeline():
gd_loss, gd_params = gd_iterations.loss, gd_iterations.params
assert jnp.isclose(gd_loss[-1], 0.0)
assert jnp.allclose(gd_params[-1], jnp.array([*truth]))
assert jnp.allclose(true_gradloss, 0.0, atol=1e-5)
assert jnp.allclose(true_gradloss, 0.0, atol=1e-4)

# Calculate grad(loss) with the more memory efficient method
loss, dloss_dparams = model.calc_loss_and_grad_from_params(truth)
[The diffs for the remaining two changed files are not shown.]