The script below freezes at ~98% (~49643/50000 steps), and I have to kill the process. However, if I uncomment the following lines, it runs to completion:

```python
if mx.any(mx.isnan(w_h)):
    raise ValueError("Blow-up")
```

Interestingly, the iterations per second also drop rapidly toward the end, even though I'm not paging and my memory isn't full. Do I need to compile this in some way, or is this a bigger issue with MLX? See the `for` loop in `navier_stokes_2d`:
```python
import mlx.core as mx
from tqdm import tqdm


def navier_stokes_2d(w0, f, visc, T, delta_t=1e-4, record_steps=1):
    """Solve 2D Navier-Stokes equation in Fourier space.

    Args:
        w0 (mx.array): initial vorticity
        f (mx.array): forcing term
        visc (float): viscosity (1/Re)
        T (float): final time
        delta_t (float): internal time-step for solve (decrease if blow-up)
        record_steps (int): number of in-time snapshots to record
    """
    N = w0.shape[-1]
    k_max = mx.floor(N / 2.0)  # max frequency
    steps = int(mx.ceil(T / delta_t))
    w_h = mx.fft.fft2(w0)
    f_h = mx.fft.fft2(f)
    # If same forcing for the whole batch
    if len(f_h.shape) < len(w_h.shape):
        f_h = mx.expand_dims(f_h, axis=0)
    record_time = mx.floor(steps / record_steps)
    k_y = mx.tile(  # y-direction wavenumber
        mx.concat((mx.arange(0, k_max), mx.arange(-k_max, 0)), 0),
        (N, 1),
    )
    k_x = k_y.swapaxes(0, 1)  # x-direction wavenumber
    # Negative Laplacian in Fourier space
    lap = 4 * (mx.pi**2) * (k_x**2 + k_y**2)
    lap[0, 0] = 1.0
    dealias = mx.logical_and(
        mx.abs(k_y) <= (2.0 / 3.0) * k_max, mx.abs(k_x) <= (2.0 / 3.0) * k_max
    ).astype(mx.float32)
    dealias = mx.expand_dims(dealias, axis=0)
    sol = mx.zeros((*w0.shape, record_steps))
    sol_t = mx.zeros((record_steps))
    c, t = 0, 0.0
    for j in tqdm(range(steps)):
        # Stream function in Fourier space: solve Poisson equation
        psi_h = w_h / lap
        # Velocity field in x-direction = psi_y
        q = mx.real(mx.fft.ifft2(-2j * mx.pi * k_y * psi_h))
        # Velocity field in y-direction = -psi_x
        v = mx.real(mx.fft.ifft2(2j * mx.pi * k_x * psi_h))
        # Partial x of vorticity
        w_x = mx.real(mx.fft.ifft2(-2j * mx.pi * k_x * w_h))
        # Partial y of vorticity
        w_y = mx.real(mx.fft.ifft2(-2j * mx.pi * k_y * w_h))
        # Non-linear term (u.grad(w)): compute in physical space then back to Fourier space
        F_h = mx.fft.fft2(q * w_x + v * w_y) * dealias
        w_h = (  # Crank-Nicolson update
            -delta_t * F_h + delta_t * f_h + (1.0 - 0.5 * delta_t * visc * lap) * w_h
        ) / (1.0 + 0.5 * delta_t * visc * lap)
        #!! WORKS FINE IF THIS IS UNCOMMENTED
        # if mx.any(mx.isnan(w_h)):
        #     raise ValueError("Blow-up")
        t += delta_t
        if (j + 1) % record_time == 0:
            sol[..., c] = mx.real(mx.fft.ifft2(w_h))
            sol_t[c] = t
            c += 1
    return sol, sol_t


def run(N=16, s=256, T=10):
    t = mx.linspace(0, 1, s)
    X, Y = mx.meshgrid(t, t, indexing="ij")
    f = 0.1 * (mx.sin(2 * mx.pi * (X + Y)) + mx.cos(2 * mx.pi * (X + Y)))
    w0 = mx.random.normal((N, s, s))
    sol, sol_t = navier_stokes_2d(w0, f, 1e-4, T, 1e-4, 200)
    return w0, sol, sol_t


if __name__ == "__main__":
    a, u, sol_t = run(N=1, s=8, T=5)
    print(a.shape, u.shape, sol_t.shape)
```
The problem is that you are never evaluating the graph, and the graph is very, very large. If you do something like:

```python
for j in range(steps):
    ### do a lot of computation
```

then you keep appending operations to your compute graph, but nothing actually triggers the evaluation. In your case `steps=50000`, and if there are, say, 100 ops per step, that means the graph is 5 million ops (!).

In general, with an iterative numerical computation you can evaluate the graph at each iteration, and that is usually fine. So just do:

```python
mx.eval(sol, sol_t)
```

at the end of your loop. For more on how this all works, check out the docs and this gist.

One more comment: doing

```python
if mx.any(mx.isnan(w_h)):
    raise ValueError("Blow-up")
```

forces the graph to evaluate, since the condition has to be computed before Python can decide which branch to take.

I'm closing this as expected behavior. If you think there is a bug or something we are missing, feel free to comment and we can reopen.