Skip to content

Commit

Permalink
Merge branch 'develop'
Browse files Browse the repository at this point in the history
  • Loading branch information
kevincar committed Feb 13, 2024
2 parents 2bbe55e + c8451c0 commit 3becacd
Show file tree
Hide file tree
Showing 13 changed files with 163 additions and 198 deletions.
4 changes: 3 additions & 1 deletion libbids/instruments/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from .instrument import Instrument
from .eeg_instrument import EEGInstrument
from .ieeg_instrument import IEEGInstrument
from .instrument import Instrument
from .physio_instrument import PhysioInstrument
from .read_instrument import ReadInstrument
from .stim_instrument import StimInstrument
from .write_instrument import WriteInstrument

EEGInstrument
IEEGInstrument
Instrument
PhysioInstrument
ReadInstrument
Expand Down
110 changes: 88 additions & 22 deletions libbids/instruments/eeg_instrument.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,16 @@ def __init__(
self,
session: "Session",
device: Any,
sfreq: int,
sfreq: Union[int, List[int]],
electrodes: List[str],
physical_dimension: str = "uV",
physical_lim: Tuple = (-1000.0, 1000.0),
preamp_filter: str = "",
record_duration: float = 1.0,
init_read_fn: Union[Tuple[str, list, Dict], Callable] = lambda: None,
read_fn: Union[Tuple[str, list, Dict], Callable] = lambda: None,
stop_fn: Union[Tuple[str, list, Dict], Callable] = lambda: None,
is_digital: bool = False,
**kwargs
):
"""Initialize a device for collecting electroecephalograms
Expand All @@ -45,8 +47,10 @@ def __init__(
Session currently using this instrument
device : Any
The a class that respresents a device for eeg data collection
sfreq : int
The device's sampling frequency
sfreq : Union[int, List[int]]
The device's sampling frequency. All channels/electrodes assume the
same sampling rate if a signle integer is provided, else a list of
sampling rates must be provided for each channel/electrode
electrodes: List[str]
A list of electrode names associated with the device
physical_dimension : str
Expand All @@ -66,14 +70,22 @@ def __init__(
callable, the function is simply called
read_fn : Union[Tuple[str, Dict], Callable]
Similar to `init_read_fn`, but used for sampling data from the device
stop_fn: Union[Tuple[str, List, Dict], Callable]
The function used to stop the actual hardware
is_digital : bool
Whether the data recorded from the device is in a digital format or
a physical floating point integer (e.g., µV)
kwargs : Dict
This keyword arguments dictionary is used to supply detailes to the
edf file header. <See
https://pyedflib.readthedocs.io/en/latest/_modules/pyedflib/edfwriter.html#EdfWriter.setHeader>
"""
super(EEGInstrument, self).__init__(session, Modality.EEG, file_ext="edf")
self.sfreq: int = sfreq
self.sfreqs: List[int] = sfreq if isinstance(sfreq, List) else [sfreq]
assert len(electrodes) > 1, "Must supply electrodes"
assert len(self.sfreqs) == 1 or len(self.sfreqs) == len(
electrodes
), "Must supply same number of sampling rates as electrodes"
self.device: Any = device
self.electrodes: List[str] = electrodes
self.physical_dimension: str = physical_dimension
Expand All @@ -82,9 +94,12 @@ def __init__(
self.record_duration: float = record_duration
self.init_read_fn: Union[Tuple[str, List, Dict], Callable] = init_read_fn
self.read_fn: Union[Tuple[str, List, Dict], Callable] = read_fn
self.stop_fn: Union[Tuple[str, List, Dict], Callable] = stop_fn
self.is_digital: bool = is_digital
self.modality_path.mkdir(exist_ok=True)
self.metadata: Dict = self._fixup_edf_metadata(kwargs)
self.buffer: np.ndarray
self.buffers: List[np.ndarray] = [np.array([]) for i in range(len(self.sfreqs))]

def annotate(self, onset: float, duration: float, description: str):
assert self.writer.writeAnnotation(onset, duration, description) == 0
Expand All @@ -93,23 +108,39 @@ def device_init_read(self):
"""Initializes reading on the device"""
if isinstance(self.init_read_fn, Callable):
self.init_read_fn()
else:
fn, args, kwargs = cast(Tuple, self.init_read_fn)
self.device.__getattribute__(fn)(*args, **kwargs)

fn, args, kwargs = cast(Tuple, self.init_read_fn)
self.device.__getattribute__(fn)(*args, **kwargs)
def device_read(self) -> Union[np.ndarray, List]:
"""read data from the device
def device_read(self) -> np.ndarray:
"""read data from the device"""
Returns
-------
np.ndarray
If all channels share the same sampling rate
List
If not all channels share the same sampling rate
"""
if isinstance(self.read_fn, Callable): # type: ignore
return cast(Callable, self.read_fn)()

fn, args, kwargs = cast(Tuple, self.read_fn)
return self.device.__getattribute__(fn)(*args, **kwargs)

def device_stop(self) -> None:
"""Stop the device"""
if isinstance(self.stop_fn, Callable): # type: ignore
return cast(Callable, self.stop_fn)()

fn, args, kwargs = cast(Tuple, self.stop_fn)
return self.device.__getattribute__(fn)(*args, **kwargs)

def flush(self) -> None:
"""Read data from the device simply to discard"""
self.device_read()

def read(self, remainder: bool = False) -> np.ndarray:
def read(self, remainder: bool = False) -> Union[List, np.ndarray]:
"""Read data from the headset and return the data
Parameters
Expand All @@ -123,18 +154,50 @@ def read(self, remainder: bool = False) -> np.ndarray:
np.ndarray
A 2D array of data in the shape of (channels, time)
"""
# samples
samples: np.ndarray = self.device_read()
self.buffer = np.c_[self.buffer, samples]
if (not remainder) and (self.buffer.shape[1] >= 256):
writebuf: np.ndarray = self.buffer[:, :256]
self.buffer = self.buffer[:, 256:]
self.writer.writeSamples(np.ascontiguousarray(writebuf))
elif remainder and (self.buffer.shape[1] > 0):
writebuf = self.buffer[:, :256]
self.writer.writeSamples(np.ascontiguousarray(writebuf))

return samples
if len(self.sfreqs) == 1:
sfreq: int = self.sfreqs[0]
period: int = int(sfreq * self.record_duration)
samples: np.ndarray = cast(np.ndarray, self.device_read())
self.buffer = np.c_[self.buffer, samples]
if (not remainder) and (self.buffer.shape[1] >= period):
n_periods: int = self.buffer.shape[1] // period
period_boundary: int = n_periods * period
writebuf: np.ndarray = self.buffer[:, :period_boundary]
self.buffer = self.buffer[:, period_boundary:]
self.writer.writeSamples(
np.ascontiguousarray(writebuf), digital=self.is_digital
)
elif remainder and (self.buffer.shape[1] > 0):
writebuf = self.buffer
self.writer.writeSamples(
np.ascontiguousarray(writebuf), digital=self.is_digital
)
return samples
else:
periods: List[int] = [int(f * self.record_duration) for f in self.sfreqs]
ch_samples: List = cast(List, self.device_read())
assert len(ch_samples) == len(
self.sfreqs
), "Data must be the same length as the number sfreqs"
self.buffers = [np.r_[i, j] for i, j in zip(self.buffers, ch_samples)]
period_met: np.bool_ = np.all(
[i.shape[0] >= j for i, j in zip(self.buffers, periods)]
)
has_data: np.bool_ = np.any([i.shape[0] > 0 for i in self.buffers])
if (not remainder) and period_met:
n_periodss: List[int] = [
i.shape[0] // j for i, j in zip(self.buffers, periods)
]
period_boundaries: List = [i * j for i, j in zip(n_periodss, periods)]
writebufs: List = [
i[:j] for i, j in zip(self.buffers, period_boundaries)
]
self.buffers = [i[j:] for i, j in zip(self.buffers, period_boundaries)]
self.writer.writeSamples(writebufs, digital=self.is_digital)
elif remainder and has_data:
writebufs = self.buffers
self.writer.writeSamples(writebufs, digital=self.is_digital)
return ch_samples

def start(self, task: str, run_id: str):
"""Begin recording a run
Expand All @@ -156,6 +219,7 @@ def start(self, task: str, run_id: str):
def stop(self):
"""Stop the run"""
super().stop()
self.device_stop()
self.writer.close()

def _fixup_edf_metadata(self, metadata: Dict):
Expand Down Expand Up @@ -203,6 +267,8 @@ def _initialize_edf_file(self) -> None:
self.writer.setPhysicalDimension(i, self.physical_dimension)
self.writer.setPhysicalMaximum(i, self.physical_lim[0])
self.writer.setPhysicalMinimum(i, self.physical_lim[1])
self.writer.setSamplefrequency(i, self.sfreq)
self.writer.setSamplefrequency(
i, self.sfreqs[0] if len(self.sfreqs) == 1 else self.sfreqs[i]
)
if "AUX" not in el:
self.writer.setPrefilter(i, self.preamp_filter)
Loading

0 comments on commit 3becacd

Please sign in to comment.