Skip to content

Commit

Permalink
chore(lib): postpone UTD for a feature PR
Browse files Browse the repository at this point in the history
  • Loading branch information
jeertmans committed Dec 20, 2024
1 parent b0068f2 commit 466a99c
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 110 deletions.
141 changes: 75 additions & 66 deletions differt/src/differt/em/_utd.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# ruff: noqa: N802, N806
from functools import partial
from typing import overload
from typing import Any, Literal, overload

import equinox as eqx
import jax
Expand All @@ -8,6 +9,8 @@
from beartype import beartype as typechecker
from jaxtyping import Array, Complex, Float, jaxtyped

from differt.utils import dot


@partial(jax.jit, inline=True)
@jaxtyped(typechecker=typechecker)
Expand All @@ -22,6 +25,25 @@ def _sign(x: Float[Array, " *batch"]) -> Float[Array, " *batch"]:
return jnp.where(x >= 0, ones, -ones)

Check warning on line 25 in differt/src/differt/em/_utd.py

View check run for this annotation

Codecov / codecov/patch

differt/src/differt/em/_utd.py#L24-L25

Added lines #L24 - L25 were not covered by tests


@partial(jax.jit, inline=True, static_argnames=("mode"))
@jaxtyped(typechecker=typechecker)
def _N(
beta: Float[Array, " *#batch"], n: Float[Array, " *#batch"], mode: Literal["+", "-"]
) -> Float[Array, " *batch"]:
if mode == "+":
return jnp.round((beta + jnp.pi) / (2 * n * jnp.pi))
return jnp.round((beta + jnp.pi) / (2 * n * jnp.pi))

Check warning on line 35 in differt/src/differt/em/_utd.py

View check run for this annotation

Codecov / codecov/patch

differt/src/differt/em/_utd.py#L33-L35

Added lines #L33 - L35 were not covered by tests


@partial(jax.jit, inline=True, static_argnames=("mode"))
@jaxtyped(typechecker=typechecker)
def _a(
beta: Float[Array, " *#batch"], n: Float[Array, " *#batch"], mode: Literal["+", "-"]
) -> Float[Array, " *batch"]:
N = _N(beta, n, mode)
return 2.0 * jax.lax.integer_pow(jnp.cos(0.5 * (2 * n * jnp.pi * N - beta)), 2)

Check warning on line 44 in differt/src/differt/em/_utd.py

View check run for this annotation

Codecov / codecov/patch

differt/src/differt/em/_utd.py#L43-L44

Added lines #L43 - L44 were not covered by tests


@overload
def L_i(
s_d: Float[Array, " *#batch"],
Expand Down Expand Up @@ -57,7 +79,7 @@ def L_i(

@eqx.filter_jit
@jaxtyped(typechecker=typechecker)
def L_i(
def L_i( # noqa: PLR0917
s_d: Float[Array, " *#batch"],
sin_2_beta_0: Float[Array, " *#batch"],
rho_1_i: Float[Array, " *#batch"] | None = None,
Expand All @@ -68,6 +90,13 @@ def L_i(
r"""
Compute the distance parameter associated with the incident shadow boundaries.
.. note::
This function can also be used to compute the distance parameters
associated with the reflection shadow boundaries for the o- and n-faces,
by passing the corresponding radii of curvature, see
:cite:`utd-mcnamara{eq. 6.28, p. 270}`.
Its general expression is given by :cite:`utd-mcnamara{eq. 6.25, p. 270}`:
.. math::
Expand Down Expand Up @@ -112,6 +141,10 @@ def L_i(
Returns:
The values of the distance parameter :math:`L_i`.
Raises:
ValueError: If 's_i' was provided along at least one of the other radius parameters,
or if one or the three 'rho' parameters was not provided.
"""
radii = (rho_1_i, rho_2_i, rho_e_i)
all_none = all(x is None for x in radii)
Expand All @@ -135,7 +168,7 @@ def L_i(

@jax.jit
@jaxtyped(typechecker=typechecker)
def F(z: Float[Array, " *batch"]) -> Complex[Array, " *batch"]: # noqa: N802
def F(z: Float[Array, " *batch"]) -> Complex[Array, " *batch"]:
r"""
Evaluate the transition function at the given points.
Expand Down Expand Up @@ -197,19 +230,18 @@ def F(z: Float[Array, " *batch"]) -> Complex[Array, " *batch"]: # noqa: N802


@jax.jit
def rays_to_sin_angles(incident_rays, diffracted_rays, edge_vectors) -> None:
"""Compute the sin of angles..."""
incident_rays, _ = normalize(incident_rays)
diffracted_rays, _ = normalize(diffracted_rays)


@jax.jit
@jaxtyped(typechecker=typechecker)
def diffraction_coefficients(
incident_ray, diffracted_ray, edge_vector, k, n, r_prime, r, r0
):
*_args: Any,
) -> None:
"""
Compute the diffraction coefficients based on the Uniform Theory of Diffraction.
Warning:
This function is not yet implemented, as we are still thinking of the
best API for it. If you want to get involved in the implementation of UTD coefficients,
please reach out to us on GitHub!
The implementation closely follows what is described
in :cite:`utd-mcnamara{p. 268-273}`.
Expand All @@ -223,7 +255,16 @@ def diffraction_coefficients(
rho_1_i: ...
rho_1_i: ...
rho_e_i: ...
Returns:
The soft and hard diffraction coefficients.
Raises:
NotImplementedError: The function is not yet implemented.
"""
# ruff: noqa: ERA001, F821, F841
raise NotImplementedError

# Ensure input vectors are normalized
incident_ray = incident_ray / jnp.linalg.norm(incident_ray)
diffracted_ray = diffracted_ray / jnp.linalg.norm(diffracted_ray)
Expand All @@ -238,65 +279,33 @@ def diffraction_coefficients(
L = r * jnp.sin(beta) ** 2 / (r + r_prime)
L_prime = r_prime * jnp.sin(beta_0) ** 2 / (r + r_prime)

# Compute the cotangent arguments
cot_arg1 = (phi + (beta - beta_0)) / (2 * n)
cot_arg2 = (phi - (beta - beta_0)) / (2 * n)
cot_arg3 = (phi + (beta + beta_0)) / (2 * n)
cot_arg4 = (phi - (beta + beta_0)) / (2 * n)

# Define the cotangent function
def cot(x):
return 1.0 / jnp.tan(x)

# Compute the a± coefficients
1 + jnp.cos(2 * n * jnp.pi - (phi + beta - beta_0))
1 + jnp.cos(2 * n * jnp.pi - (phi - beta + beta_0))

# Compute the D_s and D_h functions
def D_soft(L, cot_arg):
return (
-jnp.exp(-1j * jnp.pi / 4)
/ (2 * n * jnp.sqrt(2 * jnp.pi * k))
* cot(cot_arg)
* F(k * L * jnp.power(jnp.sin(cot_arg), 2))
)

def D_hard(L, cot_arg):
return (
-jnp.exp(-1j * jnp.pi / 4)
/ (2 * n * jnp.sqrt(2 * jnp.pi * k))
* cot(cot_arg)
* F(k * L * jnp.power(jnp.sin(cot_arg), 2))
)

# Compute the diffraction coefficients
D_s = (
D_soft(L, cot_arg1)
+ D_soft(L, cot_arg2)
+ D_soft(L_prime, cot_arg3)
+ D_soft(L_prime, cot_arg4)
)
phi_i = jnp.pi - (jnp.pi - jnp.arccos(dot(-s_t_i, t_o))) * _sign(dot(-s_t_i, n_o))
phi_d = jnp.pi - (jnp.pi - jnp.arccos(dot(+s_t_d, t_o))) * _sign(dot(+s_d_i, n_o))

D_h = (
D_hard(L, cot_arg1)
+ D_hard(L, cot_arg2)
+ D_hard(L_prime, cot_arg3)
+ D_hard(L_prime, cot_arg4)
# Compute the angle differences
phi_1 = phi_d - phi_i
phi_2 = phi_d + phi_i

# Compute the diffraction coefficients (without common mul. factor)
D_1 = _cot((jnp.pi + phi_1) / (2 * n)) * F(k * L_i * _a(phi_1, "+"))
D_2 = _cot((jnp.pi - phi_1) / (2 * n)) * F(k * L_i * _a(phi_1, "-"))
D_3 = _cot((jnp.pi + phi_2) / (2 * n)) * F(k * L_r_n * _a(phi_2, "+"))
D_4 = _cot((jnp.pi - phi_2) / (2 * n)) * F(k * L_r_o * _a(phi_2, "-"))

factor = -jnp.exp(-1j * jnp.pi / 4) / (
2 * n * jnp.sqrt(2 * jnp.pi * k) * sin_beta_0
)

# Apply the Keller cone condition
D_s = jnp.where(jnp.abs(jnp.sin(beta) - jnp.sin(beta_0)) < 1e-6, D_s, 0)
D_h = jnp.where(jnp.abs(jnp.sin(beta) - jnp.sin(beta_0)) < 1e-6, D_h, 0)

# Construct the dyadic diffraction coefficient matrix
jnp.array([[D_s, 0, 0], [0, D_h, 0], [0, 0, 0]], dtype=jnp.complex64)
# D_s = jnp.where(jnp.abs(jnp.sin(beta) - jnp.sin(beta_0)) < 1e-6, D_s, 0)
# D_h = jnp.where(jnp.abs(jnp.sin(beta) - jnp.sin(beta_0)) < 1e-6, D_h, 0)

# s_p
# TODO: below are assuming perfectly conducting surfaces

d_12 = d_1 + d_2
d_34 = d_3 + d_4
D_12 = D_1 + D_2
D_34 = D_3 + D_4

d_s = d_12 - d_34
d_h = d_12 + d_34
D_s = (D_12 - D_34) * factor
D_h = (D_12 + D_34) * factor

return d_h, d_s
return D_s, D_h
6 changes: 5 additions & 1 deletion differt/src/differt/geometry/_triangle_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,11 @@ def set_face_colors(

return self.set_face_colors(colors=colors)

Check warning on line 526 in differt/src/differt/geometry/_triangle_mesh.py

View check run for this annotation

Codecov / codecov/patch

differt/src/differt/geometry/_triangle_mesh.py#L526

Added line #L526 was not covered by tests

face_colors = jnp.broadcast_to(colors.reshape(-1, 3), self.triangles.shape)
# TODO: understand why pyright cannot determine that colors is not None
face_colors = jnp.broadcast_to(
colors.reshape(-1, 3), # type: ignore[reportOptionalMemberAccess]
self.triangles.shape,
)
return eqx.tree_at(
lambda m: m.face_colors,
self,
Expand Down
46 changes: 4 additions & 42 deletions differt/tests/em/test_utd.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,10 @@ def test_F() -> None: # noqa: N802
info = jnp.finfo(float)
got = F(info.eps)
mag = jnp.abs(got)
ang = jnp.angle(got, deg=True)
angle = jnp.angle(got, deg=True)

chex.assert_trees_all_close(mag, 0.0, atol=1e-7)
chex.assert_trees_all_close(ang, 45)
chex.assert_trees_all_close(angle, 45)

# Test case 3: F(x), x -> +oo
got = F(1e6)
Expand All @@ -102,43 +102,5 @@ def test_F() -> None: # noqa: N802


def test_diffraction_coefficients() -> None:
# Test case 1: Normal diffraction
incident_ray = jnp.array([1.0, 0.0, 0.0])
diffracted_ray = jnp.array([0.0, 1.0, 0.0])
edge_vector = jnp.array([0.0, 0.0, 1.0])
k = 2 * jnp.pi # Assuming wavelength = 1
n = 1.5
r_prime = 10.0
r = 20.0
r0 = 5.0
result = diffraction_coefficients(
incident_ray, diffracted_ray, edge_vector, k, n, r_prime, r, r0
)
assert result.shape == (3, 3)
assert jnp.iscomplexobj(result)

# Test case 2: Grazing incidence
incident_ray = jnp.array([0.0, 1.0, 0.0])
diffracted_ray = jnp.array([1.0, 0.0, 0.0])
result = diffraction_coefficients(
incident_ray, diffracted_ray, edge_vector, k, n, r_prime, r, r0
)
assert jnp.allclose(
result, jnp.zeros((3, 3)), atol=1e-6
) # Should be zero for grazing incidence

# Test case 3: Random inputs
for _ in range(10):
incident_ray = random.normal(key, (3,))
diffracted_ray = random.normal(key, (3,))
edge_vector = random.normal(key, (3,))
k = random.uniform(key, (), minval=1, maxval=10)
n = random.uniform(key, (), minval=1, maxval=2)
r_prime = random.uniform(key, (), minval=1, maxval=20)
r = random.uniform(key, (), minval=1, maxval=20)
r0 = random.uniform(key, (), minval=1, maxval=10)
result = diffraction_coefficients(
incident_ray, diffracted_ray, edge_vector, k, n, r_prime, r, r0
)
assert result.shape == (3, 3)
assert jnp.iscomplexobj(result)
with pytest.raises(NotImplementedError):
_ = diffraction_coefficients()
2 changes: 1 addition & 1 deletion docs/source/reference/differt.em.rst
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ can be obtained to express the diffraction field in function of the incident fie
:cite:`utd-mcnamara{eq. 6.13, p. 268}`:

.. math::
\boldsymbol{E}^d(P) = \boldsymbol{E}^d(Q_d) \sqrt{\frac{\rho^d}{\left(\rho_1^d+s^r\right)\left(\rho_2^r+s^r\right)}} e^{-jks^d},
\boldsymbol{E}^d(P) = \boldsymbol{E}^d(Q_d) \sqrt{\frac{\rho^d}{\left(\rho_1^d+s^d\right)\left(\rho_2^r+s^r\right)}} e^{-jks^d},
where :math:`P` is the observation point and :math:`Q_d` is the diffraction point on the edge, :math:`\rho^d` is the edge caustic distance, :math:`k` is the wavenumber, and :math:`s_d` is the distance between :math:`Q_r` and :math:`P`. Moreover, :math:`\boldsymbol{E}^d(Q_d)` can be expressed in terms of the incident field :math:`\boldsymbol{E}^i`:

Expand Down

0 comments on commit 466a99c

Please sign in to comment.