Python JAX/Numpy implementation of EGSnrc #658
14 comments · 64 replies
-
Note that in the C++ code above, using plain arrays (as in …
-
When it comes time for testing, I have a cluster we can test with: 6 compute nodes and a master node, currently running OpenPBS; the next incarnation will run Torque. I'm thinking it might be cleanest to create a new EGSnrc-Testing repository with folders for each language. Root:
We could start with this today. Now that I think about it, adding some testing to this infrastructure right now would be a big help as I continue to test Clang/Flang against GCC versions. It might make it easier for others to share their testing tools too. I think it would make sense to include performance tests like the one above, to keep track of tooling changes.
-
This is quite exciting! 🙂 🎉 @ftessier, did you still want me to build out a lookup-table-based adjustment approach in JAX to test out its timing? I'd be happy to do that. Does EGSnrc interpolate between the table entries? If so, is this just done bilinearly? (Or n-linearly, depending on how many dimensions the table has...)
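For reference, a jit-friendly table lookup in JAX can be as simple as the sketch below (illustrative only: it assumes 1-D piecewise-linear interpolation over placeholder table values, and says nothing about how EGSnrc actually stores or interpolates its tables):

```python
import jax
import jax.numpy as jnp

@jax.jit
def table_lookup(x, table_x, table_y):
    # jnp.interp does 1-D piecewise-linear interpolation over a sorted grid;
    # a bilinear (2-D) table would compose one such lookup per axis.
    return jnp.interp(x, table_x, table_y)

table_x = jnp.linspace(0.01, 1.0, 100)  # e.g. photon energies (MeV)
table_y = jnp.log(table_x)              # placeholder tabulated values
print(table_lookup(jnp.array([0.015, 0.3, 0.9]), table_x, table_y))
```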
-
So indeed, when I modify the python code to do what the C++ code does, using NumPy (no JAX) but updating particles one at a time (yet still generating all random numbers at "once" with …
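Presumably the pattern in question looks something like the following sketch (my reconstruction of the described approach, not the exact code from that comment; array names follow the C++ snippet below):

```python
import time
import numpy as np

NUM_PARTICLES = 1_000_000
ITERATIONS = 10

rng = np.random.default_rng(0)
x, y, z, u, v, w, E = (np.zeros(NUM_PARTICLES) for _ in range(7))

start = time.time()
for _ in range(ITERATIONS):
    # Generate all the random numbers for this iteration in one call...
    random_vector = rng.standard_normal(7 * NUM_PARTICLES)
    k = 0
    # ...but still update the particles one at a time, as the C++ loop does.
    for n in range(NUM_PARTICLES):
        x[n] += random_vector[k]; k += 1
        y[n] += random_vector[k]; k += 1
        z[n] += random_vector[k]; k += 1
        u[n] += random_vector[k]; k += 1
        v[n] += random_vector[k]; k += 1
        w[n] += random_vector[k]; k += 1
        E[n] += random_vector[k]; k += 1
print("duration: {:.3f} ms".format((time.time() - start) * 1000.0))
```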
-
To be fair to the C++ code, I should have generated the random numbers outside the loop, as in:

```cpp
std::vector<double> random_vector(7*NUM_PARTICLES);
std::generate(begin(random_vector), end(random_vector), bind(sample, generator));

// update particle arrays
for (int i=0; i<ITERATIONS; i++) {
    int k = 0;
    for (int n=0; n<NUM_PARTICLES; n++) {
        x[n] += random_vector[k++];
        y[n] += random_vector[k++];
        z[n] += random_vector[k++];
        u[n] += random_vector[k++];
        v[n] += random_vector[k++];
        w[n] += random_vector[k++];
        E[n] += random_vector[k++];
    }
}
```

which further improves the performance about threefold (then one could also write vector operations for the update; not sure what the optimizer does with it).
-
So, I've made a Particles dictionary that is amenable to JAX jitting. On Google's Colaboratory the timings are much the same as before:

```
# CPU:
# random_walk duration: 5368.768 ms
# random_walk duration: 1387.357 ms
# random_walk duration: 1339.349 ms
# random_walk duration: 1359.297 ms
# random_walk duration: 1344.687 ms
# random_walk duration: 1326.127 ms
# random_walk duration: 1358.951 ms
# random_walk duration: 1382.427 ms
# random_walk duration: 1355.318 ms
# random_walk duration: 1377.143 ms

# GPU:
# random_walk duration: 1373.592 ms
# random_walk duration: 15.000 ms
# random_walk duration: 14.722 ms
# random_walk duration: 14.642 ms
# random_walk duration: 21.203 ms
# random_walk duration: 14.714 ms
# random_walk duration: 14.782 ms
# random_walk duration: 14.749 ms
# random_walk duration: 14.655 ms
# random_walk duration: 14.642 ms
```

This uses the same timing approach as @ftessier used at the beginning of this discussion, for consistency. I've copied in the script for the timing tests below, but it is also available at https://github.com/SimonBiggs/egsnrc2py/blob/787eef7e28dbad2d61ef1da90343413763295259/prototyping/mypy_based_particles.py#L1-L82

```python
import time
from typing import Dict, Tuple

from typing_extensions import Literal

import matplotlib.pyplot as plt
from jax import jit, random
import jax.numpy as jnp

ParticleKeys = Literal["position", "direction", "energy"]
Particles = Dict[ParticleKeys, jnp.DeviceArray]


def random_walk(
    prng_key: jnp.DeviceArray, particles: Particles, iterations: int,
) -> Tuple[jnp.DeviceArray, Particles]:
    num_particles = particles["position"].shape[-1]

    for _ in range(iterations):
        random_normal_numbers = random.normal(prng_key, shape=(7, num_particles))
        (prng_key,) = random.split(prng_key, 1)

        particles["position"] += random_normal_numbers[0:3, :]
        particles["direction"] += random_normal_numbers[3:6, :]
        # Row 6 is the seventh and last row; the 6:7 slice keeps the
        # (1, num_particles) shape of the "energy" array.
        particles["energy"] += random_normal_numbers[6:7, :]

    return prng_key, particles


random_walk = jit(random_walk, static_argnums=(2,))


def timer(func):
    def wrap(*args, **kwargs):
        start = time.time()
        ret = func(*args, **kwargs)

        # See https://jax.readthedocs.io/en/latest/async_dispatch.html
        # for why this is needed.
        _, particles = ret
        for _, item in particles.items():
            item.block_until_ready()

        stop = time.time()
        duration = (stop - start) * 1000.0
        print("{:s} duration: {:.3f} ms".format(func.__name__, duration))
        return ret

    return wrap


random_walk = timer(random_walk)


def particles_zeros(num_particles: int) -> Particles:
    particles: Particles = {
        "position": jnp.zeros((3, num_particles)),
        "direction": jnp.zeros((3, num_particles)),
        "energy": jnp.zeros((1, num_particles)),
    }
    return particles


def main():
    seed = 0
    prng_key = random.PRNGKey(seed)

    num_particles = int(1e6)
    iterations = 10
    runs = 10

    particles = particles_zeros(num_particles)
    for _ in range(runs):
        prng_key, particles = random_walk(prng_key, particles, iterations)

    plt.scatter(particles["position"][0, 0:1000], particles["position"][1, 0:1000])
    plt.show()


if __name__ == "__main__":
    main()
```

The resulting plot looks like: [scatter of the first 1000 particle (x, y) positions; image not reproduced here]

@ftessier and @darcymason, does this approach appropriately address the concern raised earlier?
-
Addressing this next.
-
Accessing the data files is perhaps not the next logical step. In light of the discussion regarding vectorization, it makes more sense to me now to focus on the vectorization logic. Let's consider only photons for clarity, since for electrons a number of complications arise if one uses multiple scattering. Say you have a list of photons in an infinite medium, and 4 interactions (Rayleigh, photoelectric, Compton, pair production). We need to show significant efficiency gains with JIT for the sequence of sampling and applying these interactions; a rough sketch of one vectorised approach follows.
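As an illustration only (my sketch, not code from this thread, with made-up branching probabilities rather than real cross sections), the interaction type can be sampled per photon in one vectorised pass and the four outcomes combined with masks, so the jitted code contains no data-dependent branches:

```python
import jax
import jax.numpy as jnp
from jax import random

def interaction_probabilities(energy):
    # Hypothetical per-photon probabilities for (Rayleigh, photoelectric,
    # Compton, pair); real code would interpolate cross-section tables.
    logits = jnp.stack(
        [-energy, -2.0 * energy, jnp.zeros_like(energy), energy - 2.0]
    )
    return jax.nn.softmax(logits, axis=0)

@jax.jit
def interact(key, energy):
    probs = interaction_probabilities(energy)          # shape (4, n)
    # Vectorised categorical draw: count how many CDF levels each uniform
    # random number exceeds to get an interaction index in 0..3.
    cdf = jnp.cumsum(probs, axis=0)
    xi = random.uniform(key, shape=energy.shape)
    kind = jnp.sum(xi[None, :] > cdf, axis=0)

    # Compute all four outcomes, then select with masks (no branching).
    rayleigh = energy                                  # coherent: unchanged
    photo = jnp.zeros_like(energy)                     # photon absorbed
    compton = 0.5 * energy                             # placeholder energy loss
    pair = jnp.where(energy > 1.022, energy - 1.022, energy)

    outcomes = jnp.stack([rayleigh, photo, compton, pair])
    new_energy = jnp.take_along_axis(outcomes, kind[None, :], axis=0)[0]
    return new_energy, kind

key = random.PRNGKey(0)
energies = jnp.full((1_000_000,), 1.0)  # 1 MeV photons
new_energy, kind = interact(key, energies)
```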
-
Soon we'll need some better profiling tools: even for the toy codes up to now, I find large variances in wall-clock time due to system load. For python scripts there is the cProfile module, which can produce profiling data, and qcachegrind to display the results visually. I wonder if this can profile jitted code?
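For reference, a minimal cProfile workflow might look like the sketch below (`main` stands in for whatever entry point is being profiled; the dump file can then be converted for qcachegrind with a tool such as pyprof2calltree):

```python
import cProfile
import pstats

def main():
    # hypothetical entry point, e.g. the random_walk timing loop above
    sum(i * i for i in range(1_000_000))

# Run the entry point under the profiler and dump the raw data to a file.
cProfile.run("main()", "random_walk.prof")

# Print the 20 most expensive calls, sorted by cumulative time.
stats = pstats.Stats("random_walk.prof")
stats.sort_stats("cumulative").print_stats(20)
```

Note that cProfile traces Python-level calls only, so time spent inside jit-compiled XLA kernels would show up as opaque time at the Python boundary.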
-
@SimonBiggs, do you know how this can work on a local computer, i.e. not on Colab? JAX (XLA) supports only CUDA, correct?
-
Just a heads up, @ftessier and @darcymason: all transpilation prototyping work will be undertaken within the https://github.com/darcymason/egsnrc2py repo, and all JAX vectorisation prototyping within the https://github.com/SimonBiggs/pyegsnrc repo. It is probably easier to discuss each component within corresponding issues in each repo.
-
Julia can also be an option if you don't want to deal with vectorisation.
-
Still very much a work in progress... but here are some results.

First, a side note: the new REPLACE call still didn't work. I added an …

The really good news is that I've compared counts for Compton/photoelectric (both initial and 'indirect' interactions from scattered photons) for 10M particles (1 MeV, to avoid pair/triplet for now) in my 'thin two-slab' geometry (two different materials), and the counts by region are statistically the same for the Python code and the Mortran code. So this is now more than just a toy example, and moving closer to a first draft of photon-only MC.

The speed is very hard to capture in the GPU code. I'm just using the free shared-resources Colab, and the times can vary very widely: e.g. the same kernel code can run in ~60 ms or ~3500 ms. If I do a run with 5M photons and then 10M, and take the difference, it is in the range of 50-80 ms (taking the difference also removes the jit compile time). Using the lowest time as probably the closest to reality, I'm getting about a factor of 15 faster than the single-CPU Mortran code on my laptop. However, I don't know how much faster these could be on a dedicated GPU. The Python times, by the way, include just the "kernel" time, so not the setup, the transfer of data to the device and back, or the summing of interaction counts on the CPU afterwards; but those total ~30 sec or less.

I suspect access to slow memory is holding the GPU back. I've started playing around with changing the memory layout to accumulate counts in the faster GPU shared block memory (not quite working yet). The results above were from one global 'score' array with entries storing interaction counts for each thread individually, meaning a lot of accesses (10M threads * 4 regions * 4 interaction types) to the slowest memory. I intend to keep playing with the shared-memory scoring and, once that is worked out, try pair/triplet and confirm it also agrees with the Mortran code.
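For what it's worth, a sketch of the shared-memory scoring idea (my illustration, assuming a Numba CUDA kernel; the thread doesn't show the actual kernel code, and all names here are hypothetical): each block accumulates its counts in a small shared-memory scoreboard and flushes it to the global score array once per block, instead of every thread writing to slow global memory.

```python
import numpy as np
from numba import cuda, int32

NUM_REGIONS = 4
NUM_INTERACTIONS = 4
THREADS_PER_BLOCK = 256

@cuda.jit
def score_kernel(regions, interactions, global_score):
    # Per-block scoreboard in fast shared memory (4 x 4 counters).
    block_score = cuda.shared.array((NUM_REGIONS, NUM_INTERACTIONS), int32)

    t = cuda.threadIdx.x
    # Cooperatively zero the 16 shared counters.
    if t < NUM_REGIONS * NUM_INTERACTIONS:
        block_score[t // NUM_INTERACTIONS, t % NUM_INTERACTIONS] = 0
    cuda.syncthreads()

    # Each thread scores its particle with a (fast) shared-memory atomic.
    i = cuda.grid(1)
    if i < regions.size:
        cuda.atomic.add(block_score, (regions[i], interactions[i]), 1)
    cuda.syncthreads()

    # One flush per block into the global score array.
    if t < NUM_REGIONS * NUM_INTERACTIONS:
        r, k = t // NUM_INTERACTIONS, t % NUM_INTERACTIONS
        cuda.atomic.add(global_score, (r, k), block_score[r, k])

n = 10_000_000
rng = np.random.default_rng(0)
regions = cuda.to_device(rng.integers(0, NUM_REGIONS, n).astype(np.int32))
interactions = cuda.to_device(rng.integers(0, NUM_INTERACTIONS, n).astype(np.int32))
score = cuda.to_device(np.zeros((NUM_REGIONS, NUM_INTERACTIONS), dtype=np.int32))

blocks = (n + THREADS_PER_BLOCK - 1) // THREADS_PER_BLOCK
score_kernel[blocks, THREADS_PER_BLOCK](regions, interactions, score)
print(score.copy_to_host())
```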
-
A discussion in the pymedphys repository recently touched upon the topic of rewriting EGSnrc in a modern language. @SimonBiggs mentioned the possibility of using JAX/Numpy to run simulations, with "native" support for GPU compilation. This is at least worth a good look, building toy models to study performance. @SimonBiggs has shown that the out-of-the-box GPU compilation is a couple of orders of magnitude faster than straight CPU runs (for simply updating "particle" arrays). The CPU runs with Numpy are nearly as fast as C++.
For the record, here are the python code and timing results (relying on simple wall clock time, careful!) to update (r, u, E) for 1e6 particles:
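(The original python snippet and its timing output are not reproduced in this export; judging from the later comments in this thread, it presumably looked something along these lines — a reconstruction, not the original code:)

```python
import time

from jax import jit, random
import jax.numpy as jnp

@jit
def update(key, r, u, E):
    # Advance every particle at once: position r, direction u and energy E
    # each receive an independent standard-normal increment.
    key, subkey = random.split(key)
    deltas = random.normal(subkey, shape=(7, r.shape[1]))
    return key, r + deltas[0:3], u + deltas[3:6], E + deltas[6]

n = int(1e6)
r, u, E = jnp.zeros((3, n)), jnp.zeros((3, n)), jnp.zeros(n)
key = random.PRNGKey(0)

start = time.time()
key, r, u, E = update(key, r, u, E)
E.block_until_ready()  # wait for JAX's async dispatch before stopping the clock
print("update duration: {:.3f} ms".format((time.time() - start) * 1000.0))
```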
And here is the equivalent C++ code, with the same random number generator (and equally poor timer!):
Please comment further regarding the outlook for implementing EGSnrc in python. We can't expect efficiency to match optimized C++, but if we get GPU (and TPU!) compilation for "free" and gain in code clarity with python, my opinion is that it would be worth a shot.