Skip to content

Commit

Permalink
Fix logic of fftfreq and rfftfreq
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Dec 6, 2024
1 parent 3b9a84c commit 5af0051
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 58 deletions.
121 changes: 75 additions & 46 deletions brainunit/fft/_fft_change_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,17 @@
# ==============================================================================
from __future__ import annotations

import sys
from typing import Callable, Union, Sequence

import jax
import jax.numpy as jnp
import numpy as np
from jax.numpy import fft as jnpfft
from jaxlib import xla_client

from .. import _unit_common as u
from .._base import Quantity, Unit
from .._base import Quantity, Unit, get_or_create_dimension
from .._misc import set_module_as
from .._unit_common import second
from ..math._fun_change_unit import _fun_change_unit_unary
Expand Down Expand Up @@ -1021,29 +1023,43 @@ def irfftn(
# ---------------------

_time_freq_map = {
u.second: u.hertz,
u.ysecond: u.Yhertz,
u.zsecond: u.Zhertz,
u.asecond: u.Ehertz,
u.fsecond: u.Phertz,
u.psecond: u.Thertz,
u.nsecond: u.Ghertz,
u.usecond: u.Mhertz,
u.msecond: u.khertz,
u.csecond: u.hhertz,
u.dsecond: u.dahertz,
u.dasecond: u.dhertz,
u.hsecond: u.chertz,
u.ksecond: u.mhertz,
u.Msecond: u.uhertz,
u.Gsecond: u.nhertz,
u.Tsecond: u.phertz,
u.Psecond: u.fhertz,
u.Esecond: u.ahertz,
u.Zsecond: u.zhertz,
u.Ysecond: u.yhertz,
0: (u.second, u.hertz),
-24: (u.ysecond, u.Yhertz),
-21: (u.zsecond, u.Zhertz),
-18: (u.asecond, u.Ehertz),
-15: (u.fsecond, u.Phertz),
-12: (u.psecond, u.Thertz),
-9: (u.nsecond, u.Ghertz),
-6: (u.usecond, u.Mhertz),
-3: (u.msecond, u.khertz),
-2: (u.csecond, u.hhertz),
-1: (u.dsecond, u.dahertz),
1: (u.dasecond, u.dhertz),
2: (u.hsecond, u.chertz),
3: (u.ksecond, u.mhertz),
6: (u.Msecond, u.uhertz),
9: (u.Gsecond, u.nhertz),
12: (u.Tsecond, u.phertz),
15: (u.Psecond, u.fhertz),
18: (u.Esecond, u.ahertz),
21: (u.Zsecond, u.zhertz),
24: (u.Ysecond, u.yhertz),
}

def _find_closet_scale(scale):
values = list(_time_freq_map.keys())

diff = np.abs(np.array(values) - scale)

# check if all > 3, return scale
if all(diff > 3):
return scale

# find the closet index
closet_index = diff.argmin()

return values[closet_index]


@set_module_as('brainunit.fft')
def fftfreq(
Expand All @@ -1052,7 +1068,6 @@ def fftfreq(
*,
dtype: jax.typing.DTypeLike | None = None,
device: xla_client.Device | jax.sharding.Sharding | None = None,
target_freq_unit: Unit | None = None
) -> Union[Quantity, jax.typing.ArrayLike]:
"""Return sample frequencies for the discrete Fourier transform.
Expand All @@ -1076,17 +1091,25 @@ def fftfreq(
"""
if isinstance(d, Quantity):
assert d.dim == second.dim, f"Expected time unit, got {d.unit}"
if target_freq_unit is not None:
assert target_freq_unit.dim == u.hertz.dim, f"Expected frequency unit, got {target_freq_unit}"
return Quantity(jnpfft.fftfreq(n, d.mantissa, dtype=dtype, device=device), unit=target_freq_unit)
else:
try:
return Quantity(jnpfft.fftfreq(n, d.mantissa, dtype=dtype, device=device), unit=_time_freq_map[d.unit])
except:
raise TypeError(
f"Cannot convert {d.unit} to common frequency unit, please specify the target frequency unit"
f"by passing the `target_freq_unit` argument.")
return jnpfft.fftfreq(n, d, dtype=dtype, device=device)
time_scale = _find_closet_scale(d.unit.scale)
try:
time_unit, freq_unit = _time_freq_map[time_scale]
except:
time_unit = d.unit
freq_unit_scale = -d.unit.scale
freq_unit = Unit.create(get_or_create_dimension(s=-1),
name=f'10^{freq_unit_scale} hertz',
dispname=f'10^{freq_unit_scale} Hz',
scale=freq_unit_scale,)
try:
return Quantity(jnpfft.fftfreq(n, d.to_decimal(time_unit), dtype=dtype, device=device), unit=freq_unit)
except:
return Quantity(jnpfft.fftfreq(n, d.to_decimal(time_unit), dtype=dtype), unit=freq_unit)
try:
return jnpfft.fftfreq(n, d, dtype=dtype, device=device)
except:
return jnpfft.fftfreq(n, d, dtype=dtype)



@set_module_as('brainunit.fft')
Expand All @@ -1096,7 +1119,6 @@ def rfftfreq(
*,
dtype: jax.typing.DTypeLike | None = None,
device: xla_client.Device | jax.sharding.Sharding | None = None,
target_freq_unit: Unit | None = None
) -> Union[Quantity, jax.typing.ArrayLike]:
"""Return sample frequencies for the discrete Fourier transform.
Expand All @@ -1121,14 +1143,21 @@ def rfftfreq(
"""
if isinstance(d, Quantity):
assert d.dim == second.dim, f"Expected time unit, got {d.unit}"
if target_freq_unit is not None:
assert target_freq_unit.dim == u.hertz.dim, f"Expected frequency unit, got {target_freq_unit}"
return Quantity(jnpfft.rfftfreq(n, d.mantissa, dtype=dtype, device=device), unit=target_freq_unit)
else:
try:
return Quantity(jnpfft.rfftfreq(n, d.mantissa, dtype=dtype, device=device), unit=_time_freq_map[d.unit])
except:
raise TypeError(
f"Cannot convert {d.unit} to common frequency unit, please specify the target frequency unit"
f"by passing the `target_freq_unit` argument.")
return jnpfft.rfftfreq(n, d, dtype=dtype, device=device)
time_scale = _find_closet_scale(d.unit.scale)
try:
time_unit, freq_unit = _time_freq_map[time_scale]
except:
time_unit = d.unit
freq_unit_scale = -d.unit.scale
freq_unit = Unit.create(get_or_create_dimension(s=-1),
name=f'10^{freq_unit_scale} hertz',
dispname=f'10^{freq_unit_scale} Hz',
scale=freq_unit_scale, )
try:
return Quantity(jnpfft.rfftfreq(n, d.to_decimal(time_unit), dtype=dtype, device=device), unit=freq_unit)
except:
return Quantity(jnpfft.rfftfreq(n, d.to_decimal(time_unit), dtype=dtype), unit=freq_unit)
try:
return jnpfft.rfftfreq(n, d, dtype=dtype, device=device)
except:
return jnpfft.rfftfreq(n, d, dtype=dtype)
18 changes: 6 additions & 12 deletions brainunit/fft/_fft_change_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ def __init__(self, *args, **kwargs):

def test_time_freq_map(self):
from brainunit.fft._fft_change_unit import _time_freq_map
for key, value in _time_freq_map.items():
for v1, v2 in _time_freq_map.values():
# print(key.scale, value.scale)
assert key.scale == -value.scale
assert v1.scale == -v2.scale

@parameterized.product(
value_axis=[
Expand Down Expand Up @@ -169,17 +169,11 @@ def test_fft_change_unit_freq(self, size, d):
q = d * meter
result = bufft_fun(size, q)

with pytest.raises(AssertionError):
q = d * second
result = bufft_fun(size, q, target_freq_unit=u.meter)

custom_time_unit = Unit.create(get_or_create_dimension(s=1), "custom_second", "cs")
custom_time_unit = Unit.create(get_or_create_dimension(s=1), "custom_second", "cs", scale=100)
custom_hertz_unit = Unit.create(get_or_create_dimension(s=-1), "custom_hertz", "ch", scale=-100)

q = d * custom_time_unit
result = bufft_fun(size, q, target_freq_unit=u.hertz)
result = bufft_fun(size, q)
expected = jnpfft_fun(size, d)
assert_quantity(result, expected, unit=u.hertz)

with pytest.raises(TypeError):
q = d * custom_time_unit
result = bufft_fun(size, q)
assert_quantity(result, expected, unit=custom_hertz_unit)

0 comments on commit 5af0051

Please sign in to comment.