diff --git a/src/tof/result.py b/src/tof/result.py index 5b4d6d0..16745dc 100644 --- a/src/tof/result.py +++ b/src/tof/result.py @@ -390,29 +390,45 @@ def __repr__(self) -> str: def __str__(self) -> str: return self.__repr__() - def to_nxevent_data(self, key: str) -> sc.DataArray: + def to_nxevent_data(self, key: Optional[str] = None) -> sc.DataArray: """ - Convert a component reading to event data that resembles event data found in a + Convert a detector reading to event data that resembles event data found in a NeXus file. Parameters ---------- key: - Name of the component. + Name of the detector. If ``None``, all detectors are included. """ start = sc.datetime("2024-01-01T12:00:00.000000") period = sc.reciprocal(self.source.frequency) - raw_data = self[key].data.flatten(to='event') - # Select only the neutrons that make it to the detector - event_data = raw_data[~raw_data.masks['blocked_by_others']].copy() + + keys = list(self._detectors.keys()) if key is None else [key] + + event_data = [] + for name in keys: + raw_data = self._detectors[name].data.flatten(to='event') + events = ( + raw_data[~raw_data.masks['blocked_by_others']] + .copy() + .drop_masks('blocked_by_others') + ) + events.coords['distance'] = sc.broadcast( + events.coords['distance'], sizes=events.sizes + ).copy() + event_data.append(events) + + event_data = sc.concat(event_data, dim=event_data[0].dim) dt = period.to(unit=event_data.coords['toa'].unit) event_time_zero = (dt * (event_data.coords['toa'] // dt)).to(dtype=int) + start event_data.coords['event_time_zero'] = event_time_zero event_data.coords['event_time_offset'] = event_data.coords.pop( 'toa' ) % period.to(unit=dt.unit) - return ( + out = ( event_data.drop_coords(['tof', 'speed', 'time', 'wavelength']) - .group('event_time_zero') + .group('distance', 'event_time_zero') .rename_dims(event_time_zero='pulse') - ) + ).rename_dims(distance='detector_number') + out.coords['Ltotal'] = out.coords.pop('distance') + return out diff --git a/tests/model_test.py b/tests/model_test.py index 99512e7..0d551eb 100644 --- a/tests/model_test.py +++ b/tests/model_test.py @@ -428,3 +428,7 @@ def test_to_nxevent_data(): assert nxevent_data.bins.concat().value.coords[ 'event_time_offset' ].max() <= sc.reciprocal(source.frequency).to(unit='us') + + # Test when we include all detectors at once + nxevent_data = res.to_nxevent_data() + assert nxevent_data.sizes == {'detector_number': 2, 'pulse': 2} diff --git a/tests/source_test.py b/tests/source_test.py index 71fd492..6947e3b 100644 --- a/tests/source_test.py +++ b/tests/source_test.py @@ -104,24 +104,9 @@ def test_creation_from_distribution(): assert sc.isclose(mid / right, sc.scalar(10.0), rtol=rtol) # Make sure distribution is monotonically increasing - locs = np.linspace(1.0, 4.0, 20) - step = 0.5 * (locs[1] - locs[0]) - for i in range(len(locs) - 2): - a = da.hist( - wavelength=sc.array( - dims=['wavelength'], - values=[locs[i] - step, locs[i] + step], - unit='angstrom', - ) - ).data.sum() - b = da.hist( - wavelength=sc.array( - dims=['wavelength'], - values=[locs[i + 1] - step, locs[i + 1] + step], - unit='angstrom', - ) - ).data.sum() - assert b > a + h = da.hist(wavelength=10) + diff = h.data[1:] - h.data[:-1] + assert sc.all(diff > sc.scalar(0.0, unit='counts')) def test_non_integer_sampling():