Skip to content

Commit

Permalink
Separate basis_function into real and complex implementations to avoi…
Browse files Browse the repository at this point in the history
…d complex to real conversion warnings
  • Loading branch information
sjperkins committed Jan 30, 2024
1 parent c620324 commit 21edff5
Showing 1 changed file with 24 additions and 29 deletions.
53 changes: 24 additions & 29 deletions africanus/model/shape/shapelets.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,25 +27,21 @@ def factorial(n):


@numba.jit(nogil=True, nopython=True, cache=True)
def basis_function(n, xx, beta, fourier=False, delta_x=-1):
if fourier:
x = 2 * np.pi * xx
scale = 1.0 / beta
else:
x = xx
scale = beta
def real_basis_function(n, xx, beta, delta_x=-1):
basis_component = 1.0 / np.sqrt(2.0**n * np.sqrt(np.pi) * factorial(n) * beta)
exponential_component = hermite(n, xx / beta) * np.exp(-(xx**2) / (2.0 * beta**2))
return basis_component * exponential_component


@numba.jit(nogil=True, nopython=True, cache=True)
def complex_basis_function(n, xx, beta, delta_x=-1):
x = 2 * np.pi * xx
scale = 1.0 / beta
basis_component = 1.0 / np.sqrt(2.0**n * np.sqrt(np.pi) * factorial(n) * scale)
exponential_component = hermite(n, x / scale) * np.exp(-(x**2) / (2.0 * scale**2))
if fourier:
return (
1.0j**n
* basis_component
* exponential_component
* np.sqrt(2 * np.pi)
/ delta_x
)
else:
return basis_component * exponential_component
return (
1.0j**n * basis_component * exponential_component * np.sqrt(2 * np.pi) / delta_x
)


@numba.jit(nogil=True, nopython=True, cache=True)
Expand Down Expand Up @@ -96,8 +92,8 @@ def shapelet(coords, frequency, coeffs, beta, delta_lm, dtype=np.complex128):
0
if coeffs[src][n1, n2] == 0
else coeffs[src][n1, n2]
* basis_function(n1, fu, beta_u, True, delta_x=delta_l)
* basis_function(n2, fv, beta_v, True, delta_x=delta_m)
* complex_basis_function(n1, fu, beta_u, delta_x=delta_l)
* complex_basis_function(n2, fv, beta_v, delta_x=delta_m)
)
out_shapelets[row, chan, src] = tmp_shapelet
return out_shapelets
Expand Down Expand Up @@ -145,8 +141,8 @@ def shapelet_with_w_term(
0
if coeffs[src][n1, n2] == 0
else coeffs[src][n1, n2]
* basis_function(n1, fu, beta_u, True, delta_x=delta_l)
* basis_function(n2, fv, beta_v, True, delta_x=delta_m)
* complex_basis_function(n1, fu, beta_u, delta_x=delta_l)
* complex_basis_function(n2, fv, beta_v, delta_x=delta_m)
)
w_term = phase_steer_and_w_correct((u, v, w), (l, m), frequency[chan])
out_shapelets[row, chan, src] = tmp_shapelet * w_term
Expand Down Expand Up @@ -183,17 +179,16 @@ def shapelet_1d(u, coeffs, fourier, delta_x=1, beta=1.0):
if delta_x is None:
raise ValueError("You have to pass in a value for delta_x in Fourier mode")
out = np.zeros(nrow, dtype=np.complex128)
for row, ui in enumerate(u):
for n, c in enumerate(coeffs):
out[row] += c * complex_basis_function(n, ui, beta, delta_x=delta_x)
else:
out = np.zeros(nrow, dtype=np.float64)
for row, ui in enumerate(u):
for n, c in enumerate(coeffs):
out[row] += c * basis_function(
n, ui, beta, fourier=fourier, delta_x=delta_x
)
return out
for row, ui in enumerate(u):
for n, c in enumerate(coeffs):
out[row] += c * real_basis_function(n, ui, beta, delta_x=delta_x)


# @numba.jit(nogil=True, nopython=True, cache=True)
return out


def shapelet_2d(u, v, coeffs_l, fourier, delta_x=None, delta_y=None, beta=1.0):
Expand Down

0 comments on commit 21edff5

Please sign in to comment.