diff --git a/mne/io/constants.py b/mne/io/constants.py index 4c4f3b57158..682cf091a12 100644 --- a/mne/io/constants.py +++ b/mne/io/constants.py @@ -916,6 +916,9 @@ FIFF.FIFFV_COIL_FNIRS_FD_AC_AMPLITUDE = 304 # fNIRS frequency domain AC amplitude FIFF.FIFFV_COIL_FNIRS_FD_PHASE = 305 # fNIRS frequency domain phase FIFF.FIFFV_COIL_FNIRS_RAW = FIFF.FIFFV_COIL_FNIRS_CW_AMPLITUDE # old alias +FIFF.FIFFV_COIL_FNIRS_TD_GATED_AMPLITUDE = 306 # fNIRS time-domain gated amplitude +FIFF.FIFFV_COIL_FNIRS_TD_MOMENTS_AMPLITUDE = 307 # fNIRS time-domain moments amplitude +FIFF.FIFFV_COIL_FNIRS_PROCESSED = 308 # fNIRS processed data FIFF.FIFFV_COIL_MCG_42 = 1000 # For testing the MCG software @@ -1002,7 +1005,9 @@ FIFF.FIFFV_COIL_DIPOLE, FIFF.FIFFV_COIL_FNIRS_HBO, FIFF.FIFFV_COIL_FNIRS_HBR, FIFF.FIFFV_COIL_FNIRS_RAW, FIFF.FIFFV_COIL_FNIRS_OD, FIFF.FIFFV_COIL_FNIRS_FD_AC_AMPLITUDE, - FIFF.FIFFV_COIL_FNIRS_FD_PHASE, FIFF.FIFFV_COIL_MCG_42, + FIFF.FIFFV_COIL_FNIRS_FD_PHASE, FIFF.FIFFV_COIL_FNIRS_TD_GATED_AMPLITUDE, + FIFF.FIFFV_COIL_FNIRS_TD_MOMENTS_AMPLITUDE, FIFF.FIFFV_COIL_FNIRS_PROCESSED, + FIFF.FIFFV_COIL_MCG_42, FIFF.FIFFV_COIL_POINT_MAGNETOMETER, FIFF.FIFFV_COIL_AXIAL_GRAD_5CM, FIFF.FIFFV_COIL_VV_PLANAR_W, FIFF.FIFFV_COIL_VV_PLANAR_T1, FIFF.FIFFV_COIL_VV_PLANAR_T2, FIFF.FIFFV_COIL_VV_PLANAR_T3, @@ -1059,3 +1064,22 @@ # add a comment here (with doi of a published source) above any new # aliases, as they are added } + +# SNIRF: Supported measurementList(k).dataTypeLabel values in dataTimeSeries +FNIRS_SNIRF_DATATYPELABELS = { + # These types are specified `here `_ + "HbO": 1, # Oxygenated hemoglobin (oxyhemoglobin) concentration + "HbR": 2, # Deoxygenated hemoglobin (deoxyhemoglobin) concentration + "HbT": 3, # Total hemoglobin concentration + "dOD": 4, # Change in optical density + "mua": 5, # Absorption coefficient + "musp": 6, # Scattering coefficient + "H2O": 7, # Water content + "Lipid": 8, # Lipid concentration + "BFi": 9, # Blood flow index + "HRF dOD": 10, # Hemodynamic response function for change in optical density + "HRF HbO": 11, # Hemodynamic response function for oxyhemoglobin concentration + "HRF HbR": 12, # Hemodynamic response function for deoxyhemoglobin concentration + "HRF HbT": 13, # Hemodynamic response function for total hemoglobin concentration + "HRF BFi": 14, # Hemodynamic response function for blood flow index +} \ No newline at end of file diff --git a/mne/io/pick.py b/mne/io/pick.py index d84b0a1a4f1..36c8587358f 100644 --- a/mne/io/pick.py +++ b/mne/io/pick.py @@ -76,6 +76,18 @@ def get_channel_type_constants(include_defaults=False): kind=FIFF.FIFFV_FNIRS_CH, unit=FIFF.FIFF_UNIT_V, coil_type=FIFF.FIFFV_COIL_FNIRS_CW_AMPLITUDE), + fnirs_td_gated_amplitude=dict( + kind=FIFF.FIFFV_FNIRS_CH, + unit=FIFF.FIFF_UNIT_V, + coil_type=FIFF.FIFFV_COIL_FNIRS_TD_GATED_AMPLITUDE), + fnirs_td_moments_amplitude=dict( + kind=FIFF.FIFFV_FNIRS_CH, + unit=FIFF.FIFF_UNIT_V, + coil_type=FIFF.FIFFV_COIL_FNIRS_TD_MOMENTS_AMPLITUDE), + fnirs_processed=dict( + kind=FIFF.FIFFV_FNIRS_CH, + unit=FIFF.FIFF_UNIT_V, + coil_type=FIFF.FIFFV_COIL_FNIRS_PROCESSED), fnirs_fd_ac_amplitude=dict( kind=FIFF.FIFFV_FNIRS_CH, unit=FIFF.FIFF_UNIT_V, @@ -160,6 +172,12 @@ def get_channel_type_constants(include_defaults=False): FIFF.FIFFV_COIL_FNIRS_FD_PHASE: 'fnirs_fd_phase', FIFF.FIFFV_COIL_FNIRS_OD: 'fnirs_od', + FIFF.FIFFV_COIL_FNIRS_TD_GATED_AMPLITUDE: + 'fnirs_td_gated_amplitude', + FIFF.FIFFV_COIL_FNIRS_TD_MOMENTS_AMPLITUDE: + 'fnirs_td_moments_amplitude', + FIFF.FIFFV_COIL_FNIRS_PROCESSED: + 'fnirs_processed', }), 'eeg': ('coil_type', {FIFF.FIFFV_COIL_EEG: 'eeg', FIFF.FIFFV_COIL_EEG_BIPOLAR: 'eeg', @@ -966,8 +984,10 @@ def _check_excludes_includes(chs, info=None, allow_bads=False): dbs=True) _PICK_TYPES_KEYS = tuple(list(_PICK_TYPES_DATA_DICT) + ['ref_meg']) _MEG_CH_TYPES_SPLIT = ('mag', 'grad', 'planar1', 'planar2') -_FNIRS_CH_TYPES_SPLIT = ('hbo', 'hbr', 'fnirs_cw_amplitude', - 'fnirs_fd_ac_amplitude', 'fnirs_fd_phase', 'fnirs_od') +_FNIRS_CH_TYPES_SPLIT = ( + 'hbo', 'hbr', 'fnirs_cw_amplitude', 'fnirs_fd_ac_amplitude', + 'fnirs_fd_phase', 'fnirs_od', 'fnirs_td_gated_amplitude', + 'fnirs_td_moments_amplitude', 'fnirs_processed') _DATA_CH_TYPES_ORDER_DEFAULT = ( 'mag', 'grad', 'eeg', 'csd', 'eog', 'ecg', 'resp', 'emg', 'ref_meg', 'misc', 'stim', 'chpi', 'exci', 'ias', 'syst', 'seeg', 'bio', 'ecog', diff --git a/mne/io/snirf/_snirf.py b/mne/io/snirf/_snirf.py index de1070df8ef..82a07f29e45 100644 --- a/mne/io/snirf/_snirf.py +++ b/mne/io/snirf/_snirf.py @@ -12,12 +12,14 @@ from ...annotations import Annotations from ...utils import logger, verbose, fill_doc, warn, _check_fname from ...utils.check import _require_version -from ..constants import FIFF +from ..constants import FIFF, FNIRS_SNIRF_DATATYPELABELS from .._digitization import _make_dig_points from ...transforms import _frame_to_str, apply_trans from ..nirx.nirx import _convert_fnirs_to_head from ..._freesurfer import get_mni_fiducials +AVAILABLE_DATA_TYPES = [1, 201, 301, 99999] + @fill_doc def read_raw_snirf(fname, optode_frame="unknown", preload=False, verbose=None): @@ -98,14 +100,14 @@ def __init__(self, fname, optode_frame="unknown", "MNE does not support this feature. " "Only the first dataset will be processed.") - if np.array(dat.get('nirs/data1/measurementList1/dataType')) != 1: - raise RuntimeError('File does not contain continuous wave ' - 'data. MNE only supports reading continuous' - ' wave amplitude SNIRF files. Expected type' - ' code 1 but received type code %d' % - (np.array(dat.get( - 'nirs/data1/measurementList1/dataType' - )))) + snirf_data_type = np.array( + dat.get('nirs/data1/measurementList1/dataType')) + if snirf_data_type not in AVAILABLE_DATA_TYPES: + raise RuntimeError( + "File does not contain the supported data types. \ + MNE only supports reading the following data types {}, \ + but received type code {}. Processing is only available \ + for data type 1 (CW data).".format(AVAILABLE_DATA_TYPES, snirf_data_type)) last_samps = dat.get('/nirs/data1/dataTimeSeries').shape[0] - 1 @@ -130,6 +132,17 @@ def __init__(self, fname, optode_frame="unknown", fnirs_wavelengths = np.array(dat.get('nirs/probe/wavelengths')) fnirs_wavelengths = [int(w) for w in fnirs_wavelengths] + # Get data type specific probe information + if snirf_data_type == 201: + fnirs_time_delays = np.array( + dat.get('nirs/probe/timeDelays')).tolist() + fnirs_time_delay_widths = np.array( + dat.get('nirs/probe/timeDelayWidths')).tolist() + elif snirf_data_type == 301: + fnirs_moment_orders = np.array( + dat.get('nirs/probe/momentOrders')) + fnirs_moment_orders = [int(m) for m in fnirs_moment_orders] + # Extract channels def atoi(text): return int(text) if text.isdigit() else text @@ -199,26 +212,101 @@ def natural_keys(text): assert len(sources) == srcPos3D.shape[0] assert len(detectors) == detPos3D.shape[0] + # Helper function for when the numpy array has shape (), i.e. just one element. + def _correct_shape(arr): + if arr.shape == (): + arr = arr[np.newaxis] + return arr + chnames = [] for chan in channels: - src_idx = int(np.array(dat.get('nirs/data1/' + - chan + '/sourceIndex'))[0]) - det_idx = int(np.array(dat.get('nirs/data1/' + - chan + '/detectorIndex'))[0]) - wve_idx = int(np.array(dat.get('nirs/data1/' + - chan + '/wavelengthIndex'))[0]) - ch_name = sources[src_idx - 1] + '_' +\ - detectors[det_idx - 1] + ' ' +\ - str(fnirs_wavelengths[wve_idx - 1]) - chnames.append(ch_name) + src_idx = int( + _correct_shape( + np.array( + dat.get( + 'nirs/data1/' + chan + '/sourceIndex'))) + [0]) + det_idx = int( + _correct_shape( + np.array( + dat.get( + 'nirs/data1/' + chan + '/detectorIndex')))[0]) + if snirf_data_type == 1: + wve_idx = int( + _correct_shape( + np.array( + dat.get( + 'nirs/data1/' + chan + + '/wavelengthIndex')))[0]) + ch_name = sources[src_idx - 1] + '_' +\ + detectors[det_idx - 1] + ' ' +\ + str(fnirs_wavelengths[wve_idx - 1]) + chnames.append(ch_name) + elif snirf_data_type == 201: + wve_idx = int( + _correct_shape( + np.array( + dat.get( + 'nirs/data1/' + chan + + '/wavelengthIndex')))[0]) + bin_idx = int( + _correct_shape( + np.array( + dat.get( + 'nirs/data1/' + chan + + '/dataTypeIndex')))[0]) + ch_name = sources[src_idx - 1] + '_' +\ + detectors[det_idx - 1] + ' ' +\ + str(fnirs_wavelengths[wve_idx - 1]) + ' ' +\ + 'bin' + str(fnirs_time_delays[bin_idx - 1]) + chnames.append(ch_name) + elif snirf_data_type == 301: + wve_idx = int( + _correct_shape( + np.array( + dat.get( + 'nirs/data1/' + chan + + '/wavelengthIndex')))[0]) + moment_idx = int( + _correct_shape( + np.array( + dat.get( + 'nirs/data1/' + chan + + '/dataTypeIndex')))[0]) + ch_name = sources[src_idx - 1] + '_' +\ + detectors[det_idx - 1] + ' ' +\ + str(fnirs_wavelengths[wve_idx - 1]) + ' ' +\ + 'moment' + str(fnirs_moment_orders[moment_idx - 1]) + chnames.append(ch_name) + elif snirf_data_type == 99999: + hb_id = _correct_shape( + np.array(dat.get('nirs/data1/' + chan + '/dataTypeLabel')))[0].decode('UTF-8') + ch_name = sources[src_idx - 1] + '_' +\ + detectors[det_idx - 1] + ' ' +\ + hb_id + chnames.append(ch_name) # Create mne structure - info = create_info(chnames, - sampling_rate, - ch_types='fnirs_cw_amplitude') + if snirf_data_type == 1: + info = create_info(chnames, + sampling_rate, + ch_types='fnirs_cw_amplitude') + elif snirf_data_type == 201: + info = create_info(chnames, + sampling_rate, + ch_types='fnirs_td_gated_amplitude') + elif snirf_data_type == 301: + info = create_info(chnames, + sampling_rate, + ch_types='fnirs_td_moments_amplitude') + elif snirf_data_type == 99999: + info = create_info(chnames, + sampling_rate, + ch_types='fnirs_processed') subject_info = {} - names = np.array(dat.get('nirs/metaDataTags/SubjectID')) + names = _correct_shape( + np.array(dat.get('nirs/metaDataTags/SubjectID'))) subject_info['first_name'] = names[0].decode('UTF-8') # Read non standard (but allowed) custom metadata tags if 'lastName' in dat.get('nirs/metaDataTags/'): @@ -239,7 +327,8 @@ def natural_keys(text): # Update info info.update(subject_info=subject_info) - LengthUnit = np.array(dat.get('/nirs/metaDataTags/LengthUnit')) + LengthUnit = _correct_shape( + np.array(dat.get('/nirs/metaDataTags/LengthUnit'))) LengthUnit = LengthUnit[0].decode('UTF-8') scal = 1 if "cm" in LengthUnit: @@ -268,21 +357,57 @@ def natural_keys(text): coord_frame = FIFF.FIFFV_COORD_UNKNOWN for idx, chan in enumerate(channels): - src_idx = int(np.array(dat.get('nirs/data1/' + - chan + '/sourceIndex'))[0]) - det_idx = int(np.array(dat.get('nirs/data1/' + - chan + '/detectorIndex'))[0]) - wve_idx = int(np.array(dat.get('nirs/data1/' + - chan + '/wavelengthIndex'))[0]) + src_idx = int( + _correct_shape( + np.array( + dat.get( + 'nirs/data1/' + chan + '/sourceIndex'))) + [0]) + det_idx = int( + _correct_shape( + np.array( + dat.get( + 'nirs/data1/' + chan + '/detectorIndex')))[0]) info['chs'][idx]['loc'][3:6] = srcPos3D[src_idx - 1, :] info['chs'][idx]['loc'][6:9] = detPos3D[det_idx - 1, :] # Store channel as mid point midpoint = (info['chs'][idx]['loc'][3:6] + info['chs'][idx]['loc'][6:9]) / 2 info['chs'][idx]['loc'][0:3] = midpoint - info['chs'][idx]['loc'][9] = fnirs_wavelengths[wve_idx - 1] info['chs'][idx]['coord_frame'] = coord_frame + # get data type specific info: + if snirf_data_type in [1, 201, 301]: + wve_idx = int( + _correct_shape( + np.array( + dat.get( + 'nirs/data1/' + chan + + '/wavelengthIndex')))[0]) + info['chs'][idx]['loc'][9] = fnirs_wavelengths[wve_idx - 1] + elif snirf_data_type == 99999: + hb_id = _correct_shape( + np.array(dat.get('nirs/data1/' + chan + '/dataTypeLabel')))[0].decode('UTF-8') + info['chs'][idx]['loc'][9] = FNIRS_SNIRF_DATATYPELABELS[hb_id] + + if snirf_data_type == 201: + bin_idx = int( + _correct_shape( + np.array( + dat.get( + 'nirs/data1/' + chan + + '/dataTypeIndex')))[0]) + info['chs'][idx]['loc'][10] = fnirs_time_delays[bin_idx - + 1] * fnirs_time_delay_widths + elif snirf_data_type == 301: + moment_idx = int( + _correct_shape( + np.array( + dat.get( + 'nirs/data1/' + chan + + '/dataTypeIndex')))[0]) + info['chs'][idx]['loc'][10] = fnirs_moment_orders[moment_idx - 1] + if 'landmarkPos3D' in dat.get('nirs/probe/'): diglocs = np.array(dat.get('/nirs/probe/landmarkPos3D')) digname = np.array(dat.get('/nirs/probe/landmarkLabels')) @@ -320,39 +445,41 @@ def natural_keys(text): info['dig'] = _format_dig_points(dig) del head_t - str_date = np.array((dat.get( - '/nirs/metaDataTags/MeasurementDate')))[0].decode('UTF-8') - str_time = np.array((dat.get( - '/nirs/metaDataTags/MeasurementTime')))[0].decode('UTF-8') - str_datetime = str_date + str_time - - # Several formats have been observed so we try each in turn - for dt_code in ['%Y-%m-%d%H:%M:%SZ', - '%Y-%m-%d%H:%M:%S']: - try: - meas_date = datetime.datetime.strptime( - str_datetime, dt_code) - except ValueError: - pass + str_date = _correct_shape(np.array((dat.get( + '/nirs/metaDataTags/MeasurementDate'))))[0].decode('UTF-8') + str_time = _correct_shape(np.array((dat.get( + '/nirs/metaDataTags/MeasurementTime'))))[0].decode('UTF-8') + str_datetime = str_date + str_time + + # Several formats have been observed so we try each in turn + for dt_code in ['%Y-%m-%d%H:%M:%SZ', + '%Y-%m-%d%H:%M:%S']: + try: + meas_date = datetime.datetime.strptime( + str_datetime, dt_code) + except ValueError: + pass + else: + break else: - break - else: - warn("Extraction of measurement date from SNIRF file failed. " - "The date is being set to January 1st, 2000, " - f"instead of {str_datetime}") - meas_date = datetime.datetime(2000, 1, 1, 0, 0, 0) - meas_date = meas_date.replace(tzinfo=datetime.timezone.utc) - info['meas_date'] = meas_date - - if 'DateOfBirth' in dat.get('nirs/metaDataTags/'): - str_birth = np.array((dat.get('/nirs/metaDataTags/' - 'DateOfBirth')))[0].decode() - birth_matched = re.fullmatch(r'(\d+)-(\d+)-(\d+)', str_birth) - if birth_matched is not None: - info["subject_info"]['birthday'] = ( - int(birth_matched.groups()[0]), - int(birth_matched.groups()[1]), - int(birth_matched.groups()[2])) + warn( + "Extraction of measurement date from SNIRF file failed. " + "The date is being set to January 1st, 2000, " + f"instead of {str_datetime}") + meas_date = datetime.datetime(2000, 1, 1, 0, 0, 0) + meas_date = meas_date.replace(tzinfo=datetime.timezone.utc) + info['meas_date'] = meas_date + + if 'DateOfBirth' in dat.get('nirs/metaDataTags/'): + str_birth = _correct_shape( + np.array((dat.get('/nirs/metaDataTags/' 'DateOfBirth'))))[0].decode() + birth_matched = re.fullmatch( + r'(\d+)-(\d+)-(\d+)', str_birth) + if birth_matched is not None: + info["subject_info"]['birthday'] = ( + int(birth_matched.groups()[0]), + int(birth_matched.groups()[1]), + int(birth_matched.groups()[2])) super(RawSNIRF, self).__init__(info, preload, filenames=[fname], last_samps=[last_samps], @@ -365,17 +492,19 @@ def natural_keys(text): data = np.atleast_2d(np.array( dat.get('/nirs/' + key + '/data'))) if data.size > 0: - desc = dat.get('/nirs/' + key + '/name')[0] + desc = _correct_shape( + np.array(dat.get('/nirs/' + key + '/name')))[0] annot.append(data[:, 0], 1.0, desc.decode('UTF-8')) self.set_annotations(annot) # Reorder channels to match expected ordering in MNE - num_chans = len(self.ch_names) - chans = [] - for idx in range(num_chans // 2): - chans.append(idx) - chans.append(idx + num_chans // 2) - self.pick(picks=chans) + if snirf_data_type in [1, 99999]: + num_chans = len(self.ch_names) + chans = [] + for idx in range(num_chans // 2): + chans.append(idx) + chans.append(idx + num_chans // 2) + self.pick(picks=chans) # Validate that the fNIRS info is correctly formatted _validate_nirs_info(self.info)