passing current time to stepper for time dependent stimulation #76

Open
Ziaeemehr opened this issue Apr 13, 2024 · 5 comments

@Ziaeemehr
Collaborator

There are some cases where dfun needs to apply a time-dependent process to the system of equations, for example a time-dependent stimulation with a given start and end time. Can we do this now, or do we need to pass time to the stepper?

something like this in loops.py:

def heun_step(x, dfun, dt, t, *args, add=zero, adhoc=None, return_euler=False):

cheers,
A

@maedoc
Member

maedoc commented Apr 15, 2024

Thanks for opening this issue; this has also been requested by @mbreyt & @i-Zaak, so let's step through some considerations for the implementation:

There are a few ways to do it; one is to go fully non-autonomous, so that the differential equations take time as an argument,

def dfun(t: float, x: ndarray, p) -> ndarray:
    stim = np.sin(t * 2*np.pi * p.stim_freq)
    ...
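
For the start/end window in the original question, this form could look like the following sketch (t_on, t_off and stim_amp are hypothetical parameters, and np is assumed to be jax.numpy so the branch stays jit-friendly):

def dfun(t: float, x: ndarray, p) -> ndarray:
    # stimulus active only within a [t_on, t_off) window
    active = np.logical_and(t >= p.t_on, t < p.t_off)
    stim = np.where(active, p.stim_amp * np.sin(t * 2*np.pi * p.stim_freq), 0.0)
    return -x + stim  # toy linear dynamics plus forcing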

another is to have time as an extra state variable, with a derivative of 1,

def dfun(x, p):
    t, ... = x  # time unpacked from the state (pseudocode)
    stim = np.sin(t * 2*np.pi * p.stim_freq)
    return np.array([1, ...])  # time's derivative is just 1
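
With this trick the stepper itself stays autonomous; one would only prepend t = 0 to the initial conditions, e.g. (a sketch, x_init as elsewhere in this thread):

x0 = np.concatenate([np.array([0.0]), x_init])  # time rides along as x[0]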

but in the case where the time-dependent process has already been evaluated and is just an array in memory of shape (n_time, ...), it might not make sense to interpolate the process in time; rather, ensure that the time stepping occurs on the same grid as the time-dependent process itself (or some suitable ratio of dt). Jax would then iterate this as an extra scan argument,

def dfun(x, p, stim):
    ...

def op(c, xs):
    t, stim = xs
    x = heun(c, dfun, dt, stim)  # c is the carried state
    return x, x  # scan wants (new carry, per-step output)

ts = np.r_[:n_time]
stim = np.sin(ts * 2*np.pi * p.stim_freq)
_, xs = jax.lax.scan(op, x, (ts, stim))

unfortunately, function signatures are not extensible (do you enjoy trying to figure out what's going on when you send data through **kwargs? I usually don't), so this gets a bit complex over time. One way to work around this is to employ a more flexible data structure which is then used in the dfun,

sim = {
    'time': 0.0,
    'stim': np.zeros((n_time, n_node)),
    'x': x_init,
    'param': ...
}

sim['stim'] = ...

ode_step, _ = vb.make_ode(...)

def sim_step(sim, t):
    sim['time'] = t
    sim['stim'] = np.sin(t * 2*np.pi * sim['param'].stim_freq)
    sim['x'] = ode_step(sim['x'], sim['param'])
    return sim, sim['x']

sim, xs = jax.lax.scan(sim_step, sim, np.r_[:n_time])

This allows more customization of the details of the simulation without increasing the argument list of many functions like dfun, heun and make_ode, while retaining the benefits of Jax like JIT, GPU parallelization & gradients:

def run_one_freq(freq):
    sim['param'].stim_freq = freq
    _, xs = jax.lax.scan(sim_step, sim, np.r_[:n_time])
    return xs

run_sweep = jax.jit(jax.vmap(run_one_freq))
run_grad = jax.jit(jax.grad(lambda f: run_one_freq(f).sum()))  # grad needs a scalar

We would want to use actual dataclasses instead of a dict to keep it easier to use, but that's the idea.
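
A minimal sketch of how such a dataclass could still flow through jax.lax.scan and jit, assuming a hand-rolled pytree registration (the Sim class and its fields here are illustrative, not an existing vbjax API):

from dataclasses import dataclass
import jax

@jax.tree_util.register_pytree_node_class
@dataclass
class Sim:
    time: float
    stim: jax.Array
    x: jax.Array

    def tree_flatten(self):
        # leaves that scan/jit trace through; no static aux data
        return (self.time, self.stim, self.x), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)

sim_step above would then read and write attributes (sim.time, sim.stim) rather than dict keys.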

Can you @Ziaeemehr @i-Zaak @mbreyt comment on these ideas?

@Ziaeemehr
Collaborator Author

The idea with a dictionary seems more customizable. I haven't worked with dataclasses, so if they're easier to use, why not?
An extra state variable also seems good as a temporary solution.
Thanks for the explanation.

@i-Zaak

i-Zaak commented Apr 15, 2024

I'm not too keen on the time-variable parameters, to be honest, and I think we also don't need fake state variables. There are two main semantic cases for the stimulus as I see it: a) the autonomous part can be separated from the rest of the model, and the stim acts as an external forcing on selected state variables (vanilla TVB), and b) the autonomous part cannot be separated, as for example when the stimulus acts as an external input to some of the populations (e.g. AdEx).

Case a) can, I think, be handled with an autonomous version of the make_XXde methods; see the sketch below. Case b) can, I think, be handled in the current architecture as an additional coupling term, evaluated in the *_net_dfun functions. In both cases, I can imagine the stimulus being prescribed either as tabular data or as analytic functions, to cover precomputed as well as prescribed stimuli.
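
As a sketch of case a), an additive forcing could be applied outside the autonomous dfun; stim_weights is a hypothetical parameter selecting which state variables are forced, and stim(t, p) is an analytic helper like the ones discussed in this thread:

def forced_step(x, t, p, dt):
    dx = dfun(x, p)                        # autonomous part, untouched
    dx = dx + p.stim_weights * stim(t, p)  # forcing on selected state variables
    return x + dt * dx                     # Euler step, for illustration only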

@i-Zaak

i-Zaak commented Apr 20, 2024

A few showers later, here are some more thoughts.

I'm starting to like the integrated time state variable more, as I think it can be approached the same way as the multi-NMM simulations.

Making the timer an explicit "domain" I think helps to reason about this.

def timer_dfun(t, p):
    return 1

def stim(t, p):  # just an auxiliary function, not integrated
    return np.sin(t * 2*np.pi * p.stim_freq)

def dfun(tx, p):
    t, x = tx  # we would do this also for the franken-dfun
    st = stim(t, p.stim)  # param name-spaces to avoid collisions?

    # reshape stim to x, apply differential weights based on nodes, stvars etc.

    xd = nmm_dfun(x, [c, st], p.nmm)  # c is the coupling term; or c+stim, depends on the nmm...
    td = timer_dfun(t, p)

    return np.array([td, xd])  # shouldn't this be a list?

Alternatively, when the stimulus is additive in the nmm dfun, one would do something like this:

def dfun(tx, p):
    t, x = tx  # we would do this also for the franken-dfun
    st = stim(t, p.stim)

    xd = nmm_dfun(x, c, p.nmm)
    td = timer_dfun(t, p)

    return np.array([td, xd + st])

The case where the stimulus is precomputed can also be implemented as a closure of stim over dt, with a rounded division inside; that, I think, is not an issue. A sketch follows.
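
For instance (make_tabular_stim is hypothetical, and np is assumed to be jax.numpy so the index can be traced):

def make_tabular_stim(stim_table, dt):
    # stim_table: precomputed stimulus of shape (n_time, ...)
    def stim(t, p):
        i = np.round(t / dt).astype(int)  # rounded division onto the time grid
        return stim_table[i]
    return stim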

Thoughts? :)

@maedoc
Member

maedoc commented Jun 13, 2024

Thanks for unpacking that idea a bit. Since there are endless things that could go into this, and it's best to keep function args simple and unchanging for the API, I think I'd want to lean towards a dataclass approach

@dataclass
class State:
    step: int
    t: float  # just step*dt
    x: states
    ...

def dfun(state: State, p: Param) -> Deriv:
    ...

which is a little more obvious than tuples of tuples, and jax.jit makes this free.
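
A sketch of how a stepper could carry such a State through scan (assuming State is registered as a pytree, e.g. as in the earlier Sim sketch, and that Deriv mirrors State's fields; the Euler update is for illustration only):

def sim_step(state: State, _):
    d = dfun(state, p)                     # derivatives for the state fields
    x = state.x + dt * d.x                 # Euler step, for illustration
    state = State(step=state.step + 1, t=(state.step + 1) * dt, x=x)
    return state, x                        # new carry and per-step output

state, xs = jax.lax.scan(sim_step, State(step=0, t=0.0, x=x0), np.r_[:n_time])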
