Skip to content

Commit

Permalink
[BugFix][xrt] Protect against OoB read/writes.
Browse files Browse the repository at this point in the history
  • Loading branch information
SepandKashani committed Dec 1, 2023
1 parent eb2b88e commit 41d8aa1
Showing 1 changed file with 12 additions and 31 deletions.
43 changes: 12 additions & 31 deletions src/pyxu/experimental/xray/_rt.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,8 +617,8 @@ def ray_step(r: Ray3f) -> Ray3f:
)

# Compute step size to closest bounding box wall.
# (a1, a2) may contain negative values or Infs.
# In any case, we must always choose min(a1, a2) > 0.
# (a1, a2) may contain negative values or Infs.
# In any case, we must always choose min(a1, a2) > 0.
_, a1, a2 = bbox.ray_intersect(r_local)
a_min = dr.minimum(a1, a2)
a_max = dr.maximum(a1, a2)
Expand Down Expand Up @@ -670,11 +670,9 @@ def xrt_apply(
loop = drb.Loop("XRT FW", lambda: (r, r_next, active, P))
while loop(active):
# Read (I,) at current cell
# Careful to disable out-of-bound queries.
# [This may occur if FP-error caused r_next(above) to not enter the lattice;
# auto-rectified at next iteration.]
# Careful to disable out-of-bound queries. (Due to FP-errors.)
idx_I = dr.floor(0.5 * (r_next.o + r.o))
mask = active & dr.all(idx_I >= 0)
mask = active & dr.all(0 <= idx_I) & dr.all(idx_I < N)
weight = dr.gather(type(I), I, flat_index(idx_I), mask)

# Compute constants
Expand Down Expand Up @@ -733,11 +731,9 @@ def wxrt_apply(
loop = drb.Loop("WXRT FW", lambda: (r, r_next, active, P, d_acc))
while loop(active):
# Read (I, w) at current cell
# Careful to disable out-of-bound queries.
# [This may occur if FP-error caused r_next(above) to not enter the lattice;
# auto-rectified at next iteration.]
# Careful to disable out-of-bound queries. (Due to FP-errors.)
idx_I = dr.floor(0.5 * (r_next.o + r.o))
mask = active & dr.all(idx_I >= 0)
mask = active & dr.all(0 <= idx_I) & dr.all(idx_I < N)
weight = dr.gather(type(I), I, flat_index(idx_I), mask)
decay = dr.gather(type(w), w, flat_index(idx_I), mask)

Expand Down Expand Up @@ -800,22 +796,15 @@ def xrt_adjoint(
active &= dr.neq(P, 0)
loop = drb.Loop("XRT BW", lambda: (r, r_next, active))
while loop(active):
# Careful to disable out-of-bound queries. (Due to FP-errors.)
idx_I = dr.floor(0.5 * (r_next.o + r.o))
mask = active & dr.all(0 <= idx_I) & dr.all(idx_I < N)

# Compute constants
length = dr.norm((r_next.o - r.o) * pitch)

# Update back-projections
dr.scatter_reduce(
dr.ReduceOp.Add,
I,
P * length,
flat_index(idx_I),
active & dr.all(idx_I >= 0),
# Careful to disable out-of-bound queries.
# [This may occur if FP-error caused r_next(above) to not enter the lattice;
# auto-rectified at next iteration.]
)
dr.scatter_reduce(dr.ReduceOp.Add, I, P * length, flat_index(idx_I), mask)

# Walk to next lattice intersection
r.assign(r_next)
Expand Down Expand Up @@ -868,11 +857,9 @@ def wxrt_adjoint(
loop = drb.Loop("WXRT BW", lambda: (r, r_next, active, d_acc))
while loop(active):
# Read (w,) at current cell
# Careful to disable out-of-bound queries.
# [This may occur if FP-error caused r_next(above) to not enter the lattice;
# auto-rectified at next iteration.]
# Careful to disable out-of-bound queries. (Due to FP-errors.)
idx_I = dr.floor(0.5 * (r_next.o + r.o))
mask = active & dr.all(idx_I >= 0)
mask = active & dr.all(0 <= idx_I) & dr.all(idx_I < N)
decay = dr.gather(type(w), w, flat_index(idx_I), mask)

# Compute constants
Expand All @@ -885,13 +872,7 @@ def wxrt_adjoint(
)

# Update back-projections
dr.scatter_reduce(
dr.ReduceOp.Add,
I,
P * A * B,
flat_index(idx_I),
mask,
)
dr.scatter_reduce(dr.ReduceOp.Add, I, P * A * B, flat_index(idx_I), mask)
d_acc += decay * length

# Walk to next lattice intersection
Expand Down

0 comments on commit 41d8aa1

Please sign in to comment.