From c9d0c930cc264c48f534a6131004e58ff0c62802 Mon Sep 17 00:00:00 2001 From: Simon Perkins Date: Thu, 31 Oct 2024 10:46:58 +0200 Subject: [PATCH] Optimise beam cube (#320) --- HISTORY.rst | 1 + africanus/experimental/rime/fused/core.py | 13 +- .../experimental/rime/fused/terms/cube_dde.py | 369 ++++++++++-------- 3 files changed, 219 insertions(+), 164 deletions(-) diff --git a/HISTORY.rst b/HISTORY.rst index cc3c2771..ea25a02e 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -4,6 +4,7 @@ History 0.3.8 (2024-09-29) ------------------ +* Optimise the beam cube implementation (:pr:`320`) * Support an `init_state` argument into both `Term.init_fields` and `Transformer.init_fields` (:pr:`319`) * Use virtualenv to setup github CI test environments (:pr:`321`) diff --git a/africanus/experimental/rime/fused/core.py b/africanus/experimental/rime/fused/core.py index a6595a8f..362197ad 100644 --- a/africanus/experimental/rime/fused/core.py +++ b/africanus/experimental/rime/fused/core.py @@ -41,7 +41,7 @@ def rime_impl(*args): def nb_rime(*args): if not len(args) > 0: raise TypeError( - "rime must be at least be called " "with the signature argument" + "rime must be at least be called with the signature argument" ) if not isinstance(args[0], types.Literal): @@ -103,13 +103,13 @@ def impl(*args): for ch in range(nchan): X = term_sampler(state, s, r, t, f1, f2, a1, a2, ch) - for c, value in enumerate(numba.literal_unroll(X)): + for co, value in enumerate(numba.literal_unroll(X)): # Kahan summation - y = value - compensation[r, ch, c] - current = vis[r, ch, c] + y = value - compensation[r, ch, co] + current = vis[r, ch, co] x = current + y - compensation[r, ch, c] = (x - current) - y - vis[r, ch, c] = x + compensation[r, ch, co] = (x - current) - y + vis[r, ch, co] = x return vis @@ -204,7 +204,6 @@ def dask_blockwise_args(self, **kwargs): def __call__(self, time, antenna1, antenna2, feed1, feed2, **kwargs): keys = self.REQUIRED_ARGS_LITERAL + tuple(map(types.literal, kwargs.keys())) - args = keys + (time, antenna1, antenna2, feed1, feed2) + tuple(kwargs.values()) return self.impl(types.literal(self.rime_spec.spec_hash), *args) diff --git a/africanus/experimental/rime/fused/terms/cube_dde.py b/africanus/experimental/rime/fused/terms/cube_dde.py index 78eeb782..aaa87e8c 100644 --- a/africanus/experimental/rime/fused/terms/cube_dde.py +++ b/africanus/experimental/rime/fused/terms/cube_dde.py @@ -1,5 +1,6 @@ from collections import namedtuple +import numba from numba.core import cgutils, types from numba.extending import intrinsic from numba.cpython.unsafe.tuple import tuple_setitem @@ -79,13 +80,14 @@ def init_fields( beam_antenna_scaling=None, ): ncorr = len(self.corrs) + zero_vis = zero_vis_factory(ncorr) ex_dtype = beam_lm_extents.dtype beam_info_types = [ex_dtype] * 2 + [types.int64] * 2 + [types.float64] * 2 beam_info_type = types.NamedTuple(beam_info_types, BeamInfo) fields = [ - ("beam_freq_data", chan_freq.copy(ndim=2)), - ("beam_info", beam_info_type), + # source, time, feed, antenna, chan, corr + ("sampled_beam", beam.copy(ndim=6)) ] def beam( @@ -102,15 +104,16 @@ def beam( if beam.shape[3] != ncorr: raise ValueError("Beam correlations don't match specification corrs") - freq_data = np.empty((chan_freq.shape[0], 3), chan_freq.dtype) + nchan = chan_freq.shape[0] + freq_data = np.empty((nchan, 3), chan_freq.dtype) beam_nud = beam_freq_map.shape[0] beam_lw, beam_mh, beam_nud = beam.shape[:3] if beam_lw < 2 or beam_mh < 2 or beam_nud < 2: raise ValueError("beam_lw, beam_mh and beam_nud must be >= 2") - for f in range(chan_freq.shape[0]): - freq = chan_freq[f] + for c in range(nchan): + freq = chan_freq[c] lower = 0 upper = beam_nud - 1 @@ -132,20 +135,20 @@ def beam( # Set up scaling, lower weight, lower grid pos if lower == -1: - freq_data[f, 0] = freq / beam_freq_map[0] - freq_data[f, 1] = 1.0 - freq_data[f, 2] = 0 + freq_data[c, 0] = freq / beam_freq_map[0] + freq_data[c, 1] = 1.0 + freq_data[c, 2] = 0 elif upper == beam_nud: - freq_data[f, 0] = freq / beam_freq_map[beam_nud - 1] - freq_data[f, 1] = 0.0 - freq_data[f, 2] = beam_nud - 2 + freq_data[c, 0] = freq / beam_freq_map[beam_nud - 1] + freq_data[c, 1] = 0.0 + freq_data[c, 2] = beam_nud - 2 else: - freq_data[f, 0] = 1.0 + freq_data[c, 0] = 1.0 freq_low = beam_freq_map[lower] freq_high = beam_freq_map[upper] freq_diff = freq_high - freq_low - freq_data[f, 1] = (freq_high - freq) / freq_diff - freq_data[f, 2] = lower + freq_data[c, 1] = (freq_high - freq) / freq_diff + freq_data[c, 2] = lower # Beam Extents lower_l, upper_l = beam_lm_extents[0] @@ -161,7 +164,195 @@ def beam( mscale = mmaxf / (upper_m - lower_m) beam_info = BeamInfo(lscale, mscale, lmaxi, mmaxi, lmaxf, mmaxf) - return freq_data, beam_info + + nsrc = lm.shape[0] + ntime = len(init_state.utime) + nfeed = len(init_state.ufeed) + nantenna = len(init_state.uantenna) + + sampled_beam = np.empty( + (nsrc, ntime, nfeed, nantenna, nchan, ncorr), beam.dtype + ) + + for s in range(nsrc): + l = lm[s, 0] # noqa + m = lm[s, 1] + for t in range(ntime): + for f in range(nfeed): + for a in range(nantenna): + sin_pa = beam_parangle[t, f, a, 0] + cos_pa = beam_parangle[t, f, a, 1] + + for c in range(nchan): + # Unpack frequency data + freq_scale = freq_data[c, 0] + # lower and upper frequency weights + nud = freq_data[c, 1] + inv_nud = freq_data.dtype.type(1.0) - nud + # lower and upper frequency grid position + gc0 = np.int32(freq_data[c, 2]) + gc1 = gc0 + np.int32(1) + + # Apply any frequency scaling + sl = l * freq_scale + sm = m * freq_scale + + # Add pointing errors + # tl = sl + point_errors[t, a, c, 0] + # tm = sm + point_errors[t, a, c, 1] + tl = sl + tm = sm + + # Rotate lm coordinate angle + vl = tl * cos_pa - tm * sin_pa + vm = tl * sin_pa + tm * cos_pa + + # Scale by antenna scaling + # vl *= antenna_scaling[a, f, 0] + # vm *= antenna_scaling[a, f, 1] + + # Shift into the cube coordinate system + vl = beam_info.lscale * (vl - lower_l) + vm = beam_info.mscale * (vm - lower_m) + + # Clamp the coordinates to the edges of the cube + vl = max(0.0, min(vl, beam_info.lmaxf)) + vm = max(0.0, min(vm, beam_info.mmaxf)) + + # Snap to the lower grid coordinates + gl0 = np.int32(np.floor(vl)) + gm0 = np.int32(np.floor(vm)) + + # Snap to the upper grid coordinates + gl1 = min(gl0 + np.int32(1), beam_info.lmaxi) + gm1 = min(gm0 + np.int32(1), beam_info.mmaxi) + + # Difference between grid and offset coordinates + ld = vl - gl0 + md = vm - gm0 + + corr_sum = zero_vis(beam.dtype.type(0)) + absc_sum = zero_vis(beam.real.dtype.type(0)) + + # Lower cube + weight = (1.0 - ld) * (1.0 - md) * nud + + for co in range(ncorr): + value = beam[gl0, gm0, gc0, co] + absc_sum = tuple_setitem( + absc_sum, + co, + weight * np.abs(value) + absc_sum[co], + ) + corr_sum = tuple_setitem( + corr_sum, co, weight * value + corr_sum[co] + ) + + weight = ld * (1.0 - md) * nud + + for co in range(ncorr): + value = beam[gl1, gm0, gc0, co] + absc_sum = tuple_setitem( + absc_sum, + co, + weight * np.abs(value) + absc_sum[co], + ) + corr_sum = tuple_setitem( + corr_sum, co, weight * value + corr_sum[co] + ) + + weight = (1.0 - ld) * md * nud + + for co in range(ncorr): + value = beam[gl0, gm1, gc0, co] + absc_sum = tuple_setitem( + absc_sum, + co, + weight * np.abs(value) + absc_sum[co], + ) + corr_sum = tuple_setitem( + corr_sum, co, weight * value + corr_sum[co] + ) + + weight = ld * md * nud + + for co in range(ncorr): + value = beam[gl1, gm1, gc0, co] + absc_sum = tuple_setitem( + absc_sum, + co, + weight * np.abs(value) + absc_sum[co], + ) + corr_sum = tuple_setitem( + corr_sum, co, weight * value + corr_sum[co] + ) + + # Upper cube + weight = (1.0 - ld) * (1.0 - md) * inv_nud + + for co in range(ncorr): + value = beam[gl0, gm0, gc1, co] + absc_sum = tuple_setitem( + absc_sum, + co, + weight * np.abs(value) + absc_sum[co], + ) + corr_sum = tuple_setitem( + corr_sum, co, weight * value + corr_sum[co] + ) + + weight = ld * (1.0 - md) * inv_nud + + for co in range(ncorr): + value = beam[gl1, gm0, gc1, co] + absc_sum = tuple_setitem( + absc_sum, + co, + weight * np.abs(value) + absc_sum[co], + ) + corr_sum = tuple_setitem( + corr_sum, co, weight * value + corr_sum[co] + ) + + weight = (1.0 - ld) * md * inv_nud + + for co in range(ncorr): + value = beam[gl0, gm1, gc1, co] + absc_sum = tuple_setitem( + absc_sum, + co, + weight * np.abs(value) + absc_sum[co], + ) + corr_sum = tuple_setitem( + corr_sum, co, weight * value + corr_sum[co] + ) + + weight = ld * md * inv_nud + + for co in range(ncorr): + value = beam[gl1, gm1, gc1, co] + absc_sum = tuple_setitem( + absc_sum, + co, + weight * np.abs(value) + absc_sum[co], + ) + corr_sum = tuple_setitem( + corr_sum, co, weight * value + corr_sum[co] + ) + + for co in range(ncorr): + div = np.abs(corr_sum[co]) + value = corr_sum[co] * absc_sum[co] + + if div != 0.0: + value /= div + + corr_sum = tuple_setitem(corr_sum, co, value) + + for co in range(ncorr): + sampled_beam[s, t, f, a, c, co] = corr_sum[co] + + return sampled_beam return fields, beam @@ -172,150 +363,14 @@ def sampler(self): def cube_dde(state, s, r, t, f1, f2, a1, a2, c): a = state.antenna1_inverse[r] if left else state.antenna2_inverse[r] - feed = state.feed1_inverse[r] if left else state.feed2_inverse[r] - sin_pa = state.beam_parangle[t, feed, a, 0] - cos_pa = state.beam_parangle[t, feed, a, 1] - - l = state.lm[s, 0] # noqa - m = state.lm[s, 1] - - # Unpack frequency data - freq_scale = state.beam_freq_data[c, 0] - # lower and upper frequency weights - nud = state.beam_freq_data[c, 1] - inv_nud = state.beam_freq_data.dtype.type(1.0) - nud - # lower and upper frequency grid position - gc0 = np.int32(state.beam_freq_data[c, 2]) - gc1 = gc0 + np.int32(1) - - # Apply any frequency scaling - sl = l * freq_scale - sm = m * freq_scale - - # Add pointing errors - # tl = sl + point_errors[t, a, c, 0] - # tm = sm + point_errors[t, a, c, 1] - tl = sl - tm = sm - - # Rotate lm coordinate angle - vl = tl * cos_pa - tm * sin_pa - vm = tl * sin_pa + tm * cos_pa - - # Scale by antenna scaling - # vl *= antenna_scaling[a, f, 0] - # vm *= antenna_scaling[a, f, 1] - - # Beam Extents - lower_l, upper_l = state.beam_lm_extents[0] - lower_m, upper_m = state.beam_lm_extents[1] - - # Shift into the cube coordinate system - vl = state.beam_info.lscale * (vl - lower_l) - vm = state.beam_info.mscale * (vm - lower_m) - - # Clamp the coordinates to the edges of the cube - vl = max(0.0, min(vl, state.beam_info.lmaxf)) - vm = max(0.0, min(vm, state.beam_info.mmaxf)) - - # Snap to the lower grid coordinates - gl0 = np.int32(np.floor(vl)) - gm0 = np.int32(np.floor(vm)) - - # Snap to the upper grid coordinates - gl1 = min(gl0 + np.int32(1), state.beam_info.lmaxi) - gm1 = min(gm0 + np.int32(1), state.beam_info.mmaxi) - - # Difference between grid and offset coordinates - ld = vl - gl0 - md = vm - gm0 - - corr_sum = zero_vis(state.beam.dtype.type(0)) - absc_sum = zero_vis(state.beam.real.dtype.type(0)) - - # Lower cube - weight = (1.0 - ld) * (1.0 - md) * nud - - for co in range(ncorr): - value = state.beam[gl0, gm0, gc0, co] - absc_sum = tuple_setitem( - absc_sum, co, weight * np.abs(value) + absc_sum[co] - ) - corr_sum = tuple_setitem(corr_sum, co, weight * value + corr_sum[co]) - - weight = ld * (1.0 - md) * nud - - for co in range(ncorr): - value = state.beam[gl1, gm0, gc0, co] - absc_sum = tuple_setitem( - absc_sum, co, weight * np.abs(value) + absc_sum[co] - ) - corr_sum = tuple_setitem(corr_sum, co, weight * value + corr_sum[co]) + f = state.feed1_inverse[r] if left else state.feed2_inverse[r] + result = zero_vis(state.beam.dtype.type(0)) - weight = (1.0 - ld) * md * nud - - for co in range(ncorr): - value = state.beam[gl0, gm1, gc0, co] - absc_sum = tuple_setitem( - absc_sum, co, weight * np.abs(value) + absc_sum[co] - ) - corr_sum = tuple_setitem(corr_sum, co, weight * value + corr_sum[co]) - - weight = ld * md * nud - - for co in range(ncorr): - value = state.beam[gl1, gm1, gc0, co] - absc_sum = tuple_setitem( - absc_sum, co, weight * np.abs(value) + absc_sum[co] - ) - corr_sum = tuple_setitem(corr_sum, co, weight * value + corr_sum[co]) - - # Upper cube - weight = (1.0 - ld) * (1.0 - md) * inv_nud - - for co in range(ncorr): - value = state.beam[gl0, gm0, gc1, co] - absc_sum = tuple_setitem( - absc_sum, co, weight * np.abs(value) + absc_sum[co] - ) - corr_sum = tuple_setitem(corr_sum, co, weight * value + corr_sum[co]) - - weight = ld * (1.0 - md) * inv_nud - - for co in range(ncorr): - value = state.beam[gl1, gm0, gc1, co] - absc_sum = tuple_setitem( - absc_sum, co, weight * np.abs(value) + absc_sum[co] + for co in numba.literal_unroll(range(ncorr)): + result = tuple_setitem( + result, co, state.sampled_beam[s, t, f, a, c, co] ) - corr_sum = tuple_setitem(corr_sum, co, weight * value + corr_sum[co]) - - weight = (1.0 - ld) * md * inv_nud - - for co in range(ncorr): - value = state.beam[gl0, gm1, gc1, co] - absc_sum = tuple_setitem( - absc_sum, co, weight * np.abs(value) + absc_sum[co] - ) - corr_sum = tuple_setitem(corr_sum, co, weight * value + corr_sum[co]) - - weight = ld * md * inv_nud - - for co in range(ncorr): - value = state.beam[gl1, gm1, gc1, co] - absc_sum = tuple_setitem( - absc_sum, co, weight * np.abs(value) + absc_sum[co] - ) - corr_sum = tuple_setitem(corr_sum, co, weight * value + corr_sum[co]) - - for co in range(ncorr): - div = np.abs(corr_sum[co]) - value = corr_sum[co] * absc_sum[co] - - if div != 0.0: - value /= div - - corr_sum = tuple_setitem(corr_sum, co, value) - return corr_sum + return result return cube_dde