Skip to content

Commit

Permalink
more cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
bimac committed Nov 1, 2023
1 parent e849124 commit 6c7af33
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 45 deletions.
45 changes: 39 additions & 6 deletions iblrig/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,19 @@ def get_task_arguments(parents=None):
return _post_parse_arguments(**kwargs)


def _isdatetime(x: str) -> Optional[bool]:
def _is_datetime(x: str) -> bool:
"""
Check if string is a date in the format YYYY-MM-DD.
Check if a string is a date in the format YYYY-MM-DD.
:param x: The string to check
:return: True if the string matches the date format, False otherwise.
:rtype: Optional[bool]
Parameters
----------
x : str
The string to check.
Returns
-------
bool or None
True if the string matches the date format, False otherwise, or None if there's an exception.
"""
try:
datetime.strptime(x, "%Y-%m-%d")
Expand All @@ -104,7 +110,7 @@ def get_session_path(path: Union[str, Path]) -> Optional[Path]:
path = Path(path)
sess = None
for i, p in enumerate(path.parts):
if p.isdigit() and _isdatetime(path.parts[i - 1]):
if p.isdigit() and _is_datetime(path.parts[i - 1]):
sess = Path().joinpath(*path.parts[: i + 1])

return sess
Expand Down Expand Up @@ -210,3 +216,30 @@ def draw_contrast(contrast_set: Iterable[float],
return np.random.choice(contrast_set)
else:
raise ValueError("Unsupported probability_type. Use 'skew_zero', 'biased', or 'uniform'.")


def online_std(new_sample: float, new_count: int, old_mean: float, old_std: float) -> tuple[float, float]:
"""
Updates the mean and standard deviation of a group of values after a sample update
Parameters
----------
new_sample : float
The new sample to be included.
new_count : int
The new count of samples (including new_sample).
old_mean : float
The previous mean (N - 1).
old_std : float
The previous standard deviation (N - 1).
Returns
-------
tuple[float, float]
Updated mean and standard deviation.
"""
if new_count == 1:
return new_sample, 0.0
new_mean = (old_mean * (new_count - 1) + new_sample) / new_count
new_std = np.sqrt((old_std ** 2 * (new_count - 1) + (new_sample - old_mean) * (new_sample - new_mean)) / new_count)
return new_mean, new_std
27 changes: 0 additions & 27 deletions iblrig/online_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,33 +22,6 @@
sns.set_style('white')


def online_std(new_sample: float, new_count: int, old_mean: float, old_std: float) -> tuple[float, float]:
"""
Updates the mean and standard deviation of a group of values after a sample update
Parameters
----------
new_sample : float
The new sample to be included.
new_count : int
The new count of samples (including new_sample).
old_mean : float
The previous mean (N - 1).
old_std : float
The previous standard deviation (N - 1).
Returns
-------
tuple[float, float]
Updated mean and standard deviation.
"""
if new_count == 1:
return new_sample, 0.0
new_mean = (old_mean * (new_count - 1) + new_sample) / new_count
new_std = np.sqrt((old_std ** 2 * (new_count - 1) + (new_sample - old_mean) * (new_sample - new_mean)) / new_count)
return new_mean, new_std


class DataModel(object):
"""
The data model is a pure numpy / pandas container for the choice world task.
Expand Down
10 changes: 9 additions & 1 deletion iblrig/test/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from scipy import stats

from iblrig import misc
from iblrig.misc import online_std


class TestMisc(unittest.TestCase):
def test_draw_contrast(self):

n_draws = 400
n_contrasts = 10
contrast_set = np.linspace(0, 1, n_contrasts)
Expand All @@ -31,3 +31,11 @@ def assert_distribution(values: list[int], f_exp: list[float] | None = None) ->

self.assertRaises(ValueError, misc.draw_contrast, [], "incorrect_type") # assert exception for incorrect type
self.assertRaises(ValueError, misc.draw_contrast, [0, 1], "biased", 2) # assert exception for out-of-range index

def test_online_std(self):
n = 41
b = np.random.rand(n)
a = b[:-1]
mu, std = online_std(new_sample=b[-1], new_count=n, old_mean=np.mean(a), old_std=np.std(a))
np.testing.assert_almost_equal(std, np.std(b))
np.testing.assert_almost_equal(mu, np.mean(b))
11 changes: 0 additions & 11 deletions iblrig/test/test_online_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,6 @@
matplotlib.use('Agg') # avoid pyqt testing issues


class TestOnlineStd(unittest.TestCase):

def test_online_std(self):
n = 41
b = np.random.rand(n)
a = b[:-1]
mu, std = op.online_std(new_sample=b[-1], count=n, mean=np.mean(a), std=np.std(a))
np.testing.assert_almost_equal(std, np.std(b))
np.testing.assert_almost_equal(mu, np.mean(b))


class TestOnlinePlots(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
Expand Down

0 comments on commit 6c7af33

Please sign in to comment.