diff --git a/neo/rawio/blackrockrawio.py b/neo/rawio/blackrockrawio.py index fddd5386b..0b41da99f 100644 --- a/neo/rawio/blackrockrawio.py +++ b/neo/rawio/blackrockrawio.py @@ -61,6 +61,7 @@ """ import datetime +import mmap import os import re import warnings @@ -325,20 +326,6 @@ def _parse_header(self): sampling_rate = 30_000.0 / nsx_period self._nsx_sampling_frequency[nsx_nb] = float(sampling_rate) - # Parase data packages - for nsx_nb in self._avail_nsx: - - # The only way to know if it is the Precision Time Protocol of file spec 3.0 - # is to check for nanosecond timestamp resolution. - is_ptp_variant = ( - "timestamp_resolution" in self._nsx_basic_header[nsx_nb].dtype.names - and self._nsx_basic_header[nsx_nb]["timestamp_resolution"] == 1_000_000_000 - ) - if is_ptp_variant: - data_header_spec = "3.0-ptp" - else: - data_header_spec = spec_version - # nsx_to_load can be either int, list, 'max', 'all' (aka None) # here make a list only if self.nsx_to_load is None or self.nsx_to_load == "all": @@ -375,7 +362,6 @@ def _parse_header(self): # Remove if raw loading becomes possible # raise IOError("For loading Blackrock file version 2.1 .nev files are required!") - self.nsx_datas = {} # Keep public attribute for backward compatibility but let's use the private one and maybe deprecate this at some point self.sig_sampling_rates = { nsx_number: self._nsx_sampling_frequency[nsx_number] for nsx_number in self.nsx_to_load @@ -395,17 +381,11 @@ def _parse_header(self): else: data_spec = spec_version - # Parse data blocks (creates memmap, extracts data+timestamps) - data_blocks = self._parse_nsx_data(data_spec, nsx_nb) + # Parse data blocks (extracts offsets, sample counts, timestamps) + parsed_data_headers = self._parse_nsx_data(data_spec, nsx_nb) - # Segment the data (analyzes gaps, reports issues) - segments = self._segment_nsx_data(data_blocks, nsx_nb) - - # Store in existing structures for backward compatibility - self._nsx_data_header[nsx_nb] = { - seg_idx: {k: v for k, v in seg.items() if k != "data"} for seg_idx, seg in segments.items() - } - self.nsx_datas[nsx_nb] = {seg_idx: seg["data"] for seg_idx, seg in segments.items()} + # Segment the data (analyzes gaps, creates per-segment metadata) + self._nsx_data_header[nsx_nb] = self._segment_nsx_data(parsed_data_headers, nsx_nb) # Match NSX and NEV segments for v2.3 if self._avail_files["nev"]: @@ -449,7 +429,7 @@ def _parse_header(self): signal_channels.append((ch_name, ch_id, sr, sig_dtype, units, gain, offset, stream_id, buffer_id)) # check nb segment per nsx - nb_segments_for_nsx = [len(self.nsx_datas[nsx_nb]) for nsx_nb in self.nsx_to_load] + nb_segments_for_nsx = [len(self._nsx_data_header[nsx_nb]) for nsx_nb in self.nsx_to_load] if not all(nb == nb_segments_for_nsx[0] for nb in nb_segments_for_nsx): raise NeoReadWriteError("Segment nb not consistent across nsX files") self._nb_segment = nb_segments_for_nsx[0] @@ -471,19 +451,31 @@ def _parse_header(self): ts_res = 30_000 period = self._nsx_basic_header[nsx_nb]["period"] sec_per_samp = period / 30_000 # Maybe 30_000 should be ['sample_resolution'] - length = self.nsx_datas[nsx_nb][data_bl].shape[0] - timestamps = self._nsx_data_header[nsx_nb][data_bl]["timestamp"] - if timestamps is None: + seg_header = self._nsx_data_header[nsx_nb][data_bl] + length = seg_header["nb_data_points"] + timestamp = seg_header["timestamp"] + + if "timestamps_memmap_kwargs" in seg_header: + # FileSpec 3.0 with PTP -- read first/last timestamps on demand + ts_kw = seg_header["timestamps_memmap_kwargs"] + fid = self._get_nsx_fid(nsx_nb) + ts_array = self._create_mmap_view( + fid=fid, + dtype=ts_kw["dtype"], + offset=ts_kw["offset"], + num_samples=ts_kw["num_samples"], + packet_size=ts_kw.get("packet_size"), + item_offset=ts_kw.get("item_offset", 0), + ) + t_start = float(ts_array[0]) / ts_res + t_stop = max(t_stop, float(ts_array[-1]) / ts_res + sec_per_samp) + elif timestamp is None: # V2.1 format has no timestamps t_start = 0.0 t_stop = max(t_stop, length / self._nsx_sampling_frequency[nsx_nb]) - elif hasattr(timestamps, "size") and timestamps.size == length: - # FileSpec 3.0 with PTP -- use the per-sample timestamps - t_start = timestamps[0] / ts_res - t_stop = max(t_stop, timestamps[-1] / ts_res + sec_per_samp) else: # Standard format with scalar timestamp - t_start = timestamps / ts_res + t_start = timestamp / ts_res t_stop = max(t_stop, t_start + length / self._nsx_sampling_frequency[nsx_nb]) self._sigs_t_starts[nsx_nb].append(t_start) @@ -650,22 +642,117 @@ def _segment_t_stop(self, block_index, seg_index): def _get_signal_size(self, block_index, seg_index, stream_index): stream_id = self.header["signal_streams"][stream_index]["id"] nsx_nb = int(stream_id) - memmap_data = self.nsx_datas[nsx_nb][seg_index] - return memmap_data.shape[0] + return self._nsx_data_header[nsx_nb][seg_index]["nb_data_points"] def _get_signal_t_start(self, block_index, seg_index, stream_index): stream_id = self.header["signal_streams"][stream_index]["id"] nsx_nb = int(stream_id) return self._sigs_t_starts[nsx_nb][seg_index] + @staticmethod + def _create_mmap_view(fid, dtype, offset, num_samples, num_channels=None, + packet_size=None, item_offset=0): + """ + Create an np.ndarray view over a raw mmap buffer from an open file. + + When packet_size is None, creates a standard contiguous view (for + standard/v2.1 formats where samples are stored contiguously as int16). + + When packet_size is provided, creates a strided view that extracts + interleaved fields from PTP packets. Each packet has a fixed size, + and the target field starts at item_offset bytes into the packet. + The stride between rows equals packet_size, allowing the view to + skip over other fields (timestamps, reserved bytes, etc.). + + Parameters + ---------- + fid : file-like + Open file object (must support .fileno()). + dtype : str or np.dtype + Data type of the target field (e.g. "int16" for samples, + "uint64" for timestamps). + offset : int + Byte offset in the file where the data region starts. + num_samples : int + Number of samples (rows) to read. + num_channels : int or None + Number of channels (columns). None for 1D arrays (e.g. timestamps). + packet_size : int or None + Stride between consecutive rows in bytes. None for contiguous data. + item_offset : int + Byte offset of the target field within each packet. + + Returns + ------- + np.ndarray + View into the memory-mapped file. Shape is (num_samples, num_channels) + when num_channels is provided, or (num_samples,) otherwise. + """ + dtype = np.dtype(dtype) + + if num_channels is not None: + shape = (num_samples, num_channels) + else: + shape = (num_samples,) + + if packet_size is None: + bytes_per_sample = dtype.itemsize * (num_channels if num_channels is not None else 1) + start_byte = offset + length = num_samples * bytes_per_sample + else: + start_byte = offset + item_offset + length = (num_samples - 1) * packet_size + ( + dtype.itemsize * num_channels if num_channels is not None else dtype.itemsize + ) + + # mmap offset must be aligned to ALLOCATIONGRANULARITY + mmap_offset, start_remainder = divmod(start_byte, mmap.ALLOCATIONGRANULARITY) + mmap_offset *= mmap.ALLOCATIONGRANULARITY + length += start_remainder + + raw_mmap = mmap.mmap(fid.fileno(), length=length, access=mmap.ACCESS_READ, offset=mmap_offset) + + if packet_size is not None: + strides = (packet_size, dtype.itemsize) if num_channels is not None else (packet_size,) + else: + strides = None # default contiguous strides + + return np.ndarray( + shape=shape, + dtype=dtype, + buffer=raw_mmap, + offset=start_remainder, + strides=strides, + ) + + def _get_nsx_fid(self, nsx_nb): + """Open NSX file on demand and cache the file descriptor for reuse.""" + if not hasattr(self, "_nsx_fids"): + self._nsx_fids = {} + if nsx_nb not in self._nsx_fids: + filename = f"{self._filenames['nsx']}.ns{nsx_nb}" + self._nsx_fids[nsx_nb] = open(filename, "rb") + return self._nsx_fids[nsx_nb] + def _get_analogsignal_chunk(self, block_index, seg_index, i_start, i_stop, stream_index, channel_indexes): stream_id = self.header["signal_streams"][stream_index]["id"] nsx_nb = int(stream_id) - memmap_data = self.nsx_datas[nsx_nb][seg_index] + seg = self._nsx_data_header[nsx_nb][seg_index] + fid = self._get_nsx_fid(nsx_nb) + kw = seg["memmap_kwargs"] + channels = int(self._nsx_basic_header[nsx_nb]["channel_count"]) + data = self._create_mmap_view( + fid=fid, + dtype=kw["dtype"], + offset=kw["offset"], + num_samples=kw["num_samples"], + num_channels=kw.get("num_channels", channels), + packet_size=kw.get("packet_size"), + item_offset=kw.get("item_offset", 0), + ) if channel_indexes is None: channel_indexes = slice(None) - sig_chunk = memmap_data[i_start:i_stop, channel_indexes] - return sig_chunk + return data[i_start:i_stop, channel_indexes] def _get_blackrock_timestamps(self, block_index, seg_index, i_start, i_stop, stream_index): """ @@ -710,19 +797,27 @@ def _get_blackrock_timestamps(self, block_index, seg_index, i_start, i_stop, str """ stream_id = self.header["signal_streams"][stream_index]["id"] nsx_nb = int(stream_id) + seg = self._nsx_data_header[nsx_nb][seg_index] # Resolve None to concrete indices - size = self.nsx_datas[nsx_nb][seg_index].shape[0] + size = seg["nb_data_points"] i_start = i_start if i_start is not None else 0 i_stop = i_stop if i_stop is not None else size - # Check if this segment has per-sample timestamps (PTP format) - raw_timestamps = self._nsx_data_header[nsx_nb][seg_index]["timestamp"] - - if isinstance(raw_timestamps, np.ndarray) and raw_timestamps.size == size: - # PTP: real hardware timestamps + # PTP format: read timestamps on demand via strided mmap view + if "timestamps_memmap_kwargs" in seg: + ts_kw = seg["timestamps_memmap_kwargs"] + fid = self._get_nsx_fid(nsx_nb) + timestamps = self._create_mmap_view( + fid=fid, + dtype=ts_kw["dtype"], + offset=ts_kw["offset"], + num_samples=ts_kw["num_samples"], + packet_size=ts_kw.get("packet_size"), + item_offset=ts_kw.get("item_offset", 0), + ) ts_res = float(self._nsx_basic_header[nsx_nb]["timestamp_resolution"]) - return raw_timestamps[i_start:i_stop].astype("float64") / ts_res + return timestamps[i_start:i_stop].astype("float64") / ts_res else: # Non-PTP: reconstruct from t_start + index / sampling_rate t_start = self._sigs_t_starts[nsx_nb][seg_index] @@ -1073,28 +1168,27 @@ def _parse_nsx_data(self, spec, nsx_nb): Returns ------- - dict - Dictionary mapping block index to block information: - { - block_idx: { - "data": np.ndarray, - View into memory-mapped file with shape (samples, channels) - "timestamps": scalar, np.ndarray, or None, + list[dict] + List of parsed data block headers, one per block in the file: + [ + { + "timestamps": scalar or None, - Standard format: scalar (one timestamp per block) - - PTP format: array (one timestamp per sample) - v2.1 format: None (no timestamps) - # Additional metadata as needed + "memmap_kwargs": dict, + Recipe for on-demand mmap creation (offset, num_samples, dtype, etc.) + "ptp_timestamps_memmap_kwargs": dict (PTP only), + Strided access recipe for per-sample PTP timestamps }, ... - } + ] Notes ----- - - This function creates the file memmap internally - - Data views are created using np.ndarray with buffer parameter (memory efficient) + - No memmaps are stored; only lightweight metadata for on-demand creation - Returned data is NOT YET SEGMENTED (segmentation happens in a separate step) - - For standard format, each block from the file is one dict entry - - For PTP format, all data is in a single block (block_idx=0) + - For standard format, each block from the file is one list entry + - For PTP format, all data is in a single entry """ if spec == "2.1": return self._parse_nsx_data_v21(nsx_nb) @@ -1114,14 +1208,11 @@ def _parse_nsx_data_v21(self, nsx_nb): Returns ------- - dict - {0: {"data": np.ndarray, "timestamps": None}} + list[dict] + Single-element list: [{"timestamps": None, "memmap_kwargs": dict}] """ filename = f"{self._filenames['nsx']}.ns{nsx_nb}" - # Create file memmap - file_memmap = np.memmap(filename, dtype="uint8", mode="r") - # Calculate header size and data points for v2.1 channels = int(self._nsx_basic_header[nsx_nb]["channel_count"]) bytes_in_headers = ( @@ -1129,16 +1220,18 @@ def _parse_nsx_data_v21(self, nsx_nb): ) filesize = self._get_file_size(filename) num_samples = int((filesize - bytes_in_headers) / (2 * channels) - 1) - offset = bytes_in_headers - # Create data view into memmap - data = np.ndarray(shape=(num_samples, channels), dtype="int16", buffer=file_memmap, offset=offset) - return { - 0: { - "data": data, + return [ + { "timestamps": None, + "memmap_kwargs": { + "filename": filename, + "dtype": "int16", + "offset": bytes_in_headers, + "num_samples": num_samples, + }, } - } + ] def _parse_nsx_data_v22_v30(self, spec, nsx_nb): """ @@ -1151,21 +1244,17 @@ def _parse_nsx_data_v22_v30(self, spec, nsx_nb): Returns ------- - dict - {block_idx: {"data": np.ndarray, "timestamps": scalar}, ...} + list[dict] + [{"timestamps": scalar, "memmap_kwargs": dict}, ...] """ filename = f"{self._filenames['nsx']}.ns{nsx_nb}" - # Create file memmap - file_memmap = np.memmap(filename, dtype="uint8", mode="r") - # Get file parameters filesize = self._get_file_size(filename) channels = int(self._nsx_basic_header[nsx_nb]["channel_count"]) current_offset = int(self._nsx_basic_header[nsx_nb]["bytes_in_headers"]) - data_blocks = {} - block_idx = 0 + parsed_data_headers = [] # Loop through file, reading block headers while current_offset < filesize: @@ -1182,20 +1271,21 @@ def _parse_nsx_data_v22_v30(self, spec, nsx_nb): data_offset = current_offset + header.dtype.itemsize timestamp = header["timestamp"] - # Create data view into memmap for this block - data = np.ndarray(shape=(num_samples, channels), dtype="int16", buffer=file_memmap, offset=data_offset) - - data_blocks[block_idx] = { - "data": data, + parsed_data_headers.append({ "timestamps": timestamp, - } + "memmap_kwargs": { + "filename": filename, + "dtype": "int16", + "offset": data_offset, + "num_samples": num_samples, + }, + }) # Jump to next block data_size_bytes = num_samples * channels * 2 # int16 = 2 bytes current_offset = data_offset + data_size_bytes - block_idx += 1 - return data_blocks + return parsed_data_headers def _parse_nsx_data_v30_ptp(self, nsx_nb): """ @@ -1208,8 +1298,9 @@ def _parse_nsx_data_v30_ptp(self, nsx_nb): Returns ------- - dict - {0: {"data": np.ndarray, "timestamps": np.ndarray}} + list[dict] + Single-element list: [{"timestamps": None, "memmap_kwargs": dict, + "ptp_timestamps_memmap_kwargs": dict}] """ filename = f"{self._filenames['nsx']}.ns{nsx_nb}" @@ -1218,26 +1309,47 @@ def _parse_nsx_data_v30_ptp(self, nsx_nb): header_size = int(self._nsx_basic_header[nsx_nb]["bytes_in_headers"]) channel_count = int(self._nsx_basic_header[nsx_nb]["channel_count"]) - # Create structured memmap (timestamp + samples per packet) + # Create structured memmap for verification only ptp_dt = NSX_DATA_HEADER_TYPES["3.0-ptp"](channel_count) - npackets = int((filesize - header_size) / np.dtype(ptp_dt).itemsize) - file_memmap = np.memmap(filename, dtype=ptp_dt, shape=npackets, offset=header_size, mode="r") + ptp_dtype = np.dtype(ptp_dt) + npackets = int((filesize - header_size) / ptp_dtype.itemsize) + temp_memmap = np.memmap(filename, dtype=ptp_dt, shape=npackets, offset=header_size, mode="r") # Verify this is truly PTP (all packets should have 1 sample) - if not np.all(file_memmap["num_data_points"] == 1): + if not np.all(temp_memmap["num_data_points"] == 1): # Not actually PTP! Fall back to standard format + del temp_memmap return self._parse_nsx_data_v22_v30("3.0", nsx_nb) - # Extract data and timestamps from structured array - data = file_memmap["samples"] - timestamps = file_memmap["timestamps"] + del temp_memmap - return { - 0: { - "data": data, - "timestamps": timestamps, + # Compute strided access parameters from the structured dtype + packet_size = ptp_dtype.itemsize + samples_item_offset = ptp_dtype.fields["samples"][1] + timestamps_item_offset = ptp_dtype.fields["timestamps"][1] + + return [ + { + "timestamps": None, + "memmap_kwargs": { + "filename": filename, + "dtype": "int16", + "offset": header_size, + "num_samples": npackets, + "num_channels": channel_count, + "packet_size": packet_size, + "item_offset": samples_item_offset, + }, + "ptp_timestamps_memmap_kwargs": { + "filename": filename, + "dtype": "uint64", + "offset": header_size, + "num_samples": npackets, + "packet_size": packet_size, + "item_offset": timestamps_item_offset, + }, } - } + ] def _format_gap_report(self, gap_indices, timestamps_in_seconds, time_differences, nsx_nb): """ @@ -1281,11 +1393,11 @@ def _format_gap_report(self, gap_indices, timestamps_in_seconds, time_difference + "+-----------------+-----------------------+-----------------------+\n" ) - def _segment_nsx_data(self, data_blocks_dict, nsx_nb): + def _segment_nsx_data(self, parsed_data_headers, nsx_nb): """ Segment NSX data based on timestamp gaps. - Takes the data blocks returned by _parse_nsx_data() and creates segments. + Takes the parsed data headers returned by _parse_nsx_data() and creates segments. Segmentation logic depends on the file format: - Standard format (multiple blocks): Each block IS a segment @@ -1294,47 +1406,60 @@ def _segment_nsx_data(self, data_blocks_dict, nsx_nb): Parameters ---------- - data_blocks_dict : dict - Dictionary from _parse_nsx_data(): - {block_idx: {"data": np.ndarray, "timestamps": scalar/array/None}} + parsed_data_headers : list[dict] + List from _parse_nsx_data(): + [{"timestamps": scalar/None, "memmap_kwargs": dict, ...}, ...] nsx_nb : int NSX file number Returns ------- - dict - { - seg_idx: { - "data": np.ndarray, - "timestamps": scalar, array, or None, + list[dict] + [ + { + "timestamp": scalar or None, "nb_data_points": int, "header": int or None, - "offset_to_data_block": None (deprecated but kept for compatibility) + "offset_to_data_block": None (deprecated but kept for compatibility), + "memmap_kwargs": dict, + "timestamps_memmap_kwargs": dict (PTP only), }, ... - } + ] """ - segments = {} + segments = [] # Case 1: Multiple blocks (Standard format) - each block is a segment - if len(data_blocks_dict) > 1: - for block_idx, block_info in data_blocks_dict.items(): - segments[block_idx] = { - "data": block_info["data"], + if len(parsed_data_headers) > 1: + for block_info in parsed_data_headers: + segments.append({ "timestamp": block_info["timestamps"], # Use singular for backward compatibility - "nb_data_points": block_info["data"].shape[0], + "nb_data_points": block_info["memmap_kwargs"]["num_samples"], "header": 1, # Standard format has headers - "offset_to_data_block": None, # Not needed (have data directly) - } - - # Case 2: Single block - check if PTP (array timestamps) or simple (no timestamps) - elif len(data_blocks_dict) == 1: - block_info = data_blocks_dict[0] - data = block_info["data"] - timestamps = block_info["timestamps"] + "offset_to_data_block": None, # Not needed + "memmap_kwargs": block_info["memmap_kwargs"], + }) + + # Case 2: Single block - check if PTP (has ptp_timestamps_memmap_kwargs) or simple + elif len(parsed_data_headers) == 1: + block_info = parsed_data_headers[0] + + # PTP format: read timestamps on demand and detect gaps + if "ptp_timestamps_memmap_kwargs" in block_info: + ts_kw = block_info["ptp_timestamps_memmap_kwargs"] + samples_kw = block_info["memmap_kwargs"] + + # Read timestamps via strided mmap view for gap detection + fid = self._get_nsx_fid(nsx_nb) + timestamps = self._create_mmap_view( + fid=fid, + dtype=ts_kw["dtype"], + offset=ts_kw["offset"], + num_samples=ts_kw["num_samples"], + packet_size=ts_kw.get("packet_size"), + item_offset=ts_kw.get("item_offset", 0), + ) - # PTP format: array of timestamps - need to detect gaps - if isinstance(timestamps, np.ndarray): # Analyze timestamp gaps sampling_rate = self._nsx_sampling_frequency[nsx_nb] @@ -1369,29 +1494,48 @@ def _segment_nsx_data(self, data_blocks_dict, nsx_nb): gap_indices = significant_gap_indices # Create segments based on gaps + num_total_samples = ts_kw["num_samples"] segment_starts = np.hstack((0, gap_indices + 1)) - segment_boundaries = list(segment_starts) + [len(data)] + segment_boundaries = list(segment_starts) + [num_total_samples] - for seg_idx, start in enumerate(segment_starts): - end = segment_boundaries[seg_idx + 1] + packet_size = samples_kw["packet_size"] + base_samples_offset = samples_kw["offset"] + base_ts_offset = ts_kw["offset"] - segments[seg_idx] = { - "data": data[start:end], - "timestamp": timestamps[start:end], # Use singular for backward compatibility - "nb_data_points": end - start, + for seg_index, start in enumerate(segment_starts): + end = segment_boundaries[seg_index + 1] + seg_num_samples = end - start + + # Compute new file offset for this segment slice + seg_samples_offset = base_samples_offset + int(start) * packet_size + seg_ts_offset = base_ts_offset + int(start) * packet_size + + segments.append({ + "timestamp": None, # PTP timestamps read on demand + "nb_data_points": seg_num_samples, "header": None, # PTP has no headers "offset_to_data_block": None, - } + "memmap_kwargs": { + **samples_kw, + "offset": seg_samples_offset, + "num_samples": seg_num_samples, + }, + "timestamps_memmap_kwargs": { + **ts_kw, + "offset": seg_ts_offset, + "num_samples": seg_num_samples, + }, + }) # V2.1 or single block standard format: no segmentation needed else: - segments[0] = { - "data": data, - "timestamp": timestamps, # Use singular for backward compatibility - "nb_data_points": data.shape[0], + segments.append({ + "timestamp": block_info["timestamps"], # Use singular for backward compatibility + "nb_data_points": block_info["memmap_kwargs"]["num_samples"], "header": None, "offset_to_data_block": None, - } + "memmap_kwargs": block_info["memmap_kwargs"], + }) return segments @@ -1670,18 +1814,18 @@ def _match_nsx_and_nev_segment_ids(self, nsx_nb): nsx_offset = self._nsx_data_header[nsx_nb][0]["timestamp"] # Multiples of 1/30.000s that pass between two nsX samples nsx_period = self._nsx_basic_header[nsx_nb]["period"] - # NSX segments needed as dict and list - nonempty_nsx_segments = {} - list_nonempty_nsx_segments = [] # Counts how many segments CAN be created from nev nb_possible_nev_segments = self._nb_segment_nev # Nonempty segments are those containing at least 2 samples # These have to be able to be mapped to nev - for k, v in sorted(self._nsx_data_header[nsx_nb].items()): - if v["nb_data_points"] > 1: - nonempty_nsx_segments[k] = v - list_nonempty_nsx_segments.append(v) + nonempty_nsx_segment_indices = [ + seg_index for seg_index, seg in enumerate(self._nsx_data_header[nsx_nb]) + if seg["nb_data_points"] > 1 + ] + nonempty_nsx_segments = [ + self._nsx_data_header[nsx_nb][seg_index] for seg_index in nonempty_nsx_segment_indices + ] # Account for paused segments # This increases nev event segment ids if from the nsx an additional segment is found @@ -1690,7 +1834,7 @@ def _match_nsx_and_nev_segment_ids(self, nsx_nb): for k, (data, ev_ids) in self.nev_data.items(): # Check all nonempty nsX segments - for i, seg in enumerate(list_nonempty_nsx_segments[:]): + for i, seg in enumerate(nonempty_nsx_segments[:]): # Last timestamp in this nsX segment # Not subtracting nsX offset from end because spike extraction might continue @@ -1718,7 +1862,7 @@ def _match_nsx_and_nev_segment_ids(self, nsx_nb): # because a new one has been discovered if len(data[mask_after_seg]) > 0: # Warning if spikes are after last segment - if i == len(list_nonempty_nsx_segments) - 1: + if i == len(nonempty_nsx_segments) - 1: # Get timestamp resolution from header (available for v2.2+) timestamp_resolution = self._nsx_basic_header[nsx_nb]["timestamp_resolution"] time_after_seg = ( @@ -1732,7 +1876,7 @@ def _match_nsx_and_nev_segment_ids(self, nsx_nb): # If reset and no segment detected in nev, then these segments cannot be # distinguished in nev, which is a big problem # XXX 96 is an arbitrary number based on observations in available files - elif list_nonempty_nsx_segments[i + 1]["timestamp"] - nsx_offset <= 96: + elif nonempty_nsx_segments[i + 1]["timestamp"] - nsx_offset <= 96: # If not all definitely belong to the next segment, # then it cannot be distinguished where some belong if len(data[ev_ids == i]) != len(data[mask_after_seg]): @@ -1751,7 +1895,7 @@ def _match_nsx_and_nev_segment_ids(self, nsx_nb): f"ns{nsx_nb} file." ) - new_nev_segment_id_mapping = dict(zip(range(nb_possible_nev_segments), sorted(list(nonempty_nsx_segments)))) + new_nev_segment_id_mapping = dict(zip(range(nb_possible_nev_segments), nonempty_nsx_segment_indices)) # replacing event ids by matched event ids in place for k, (data, ev_ids) in self.nev_data.items(): @@ -1996,36 +2140,25 @@ def _delete_empty_segments(self): segment in the nsX data. """ - # Discard empty segments - removed_seg = [] - for data_bl in range(self._nb_segment): + # Find empty segments (fewer than 2 samples across all nsx files) + empty_indices = [] + for seg_index in range(self._nb_segment): keep_seg = True for nsx_nb in self.nsx_to_load: - length = self.nsx_datas[nsx_nb][data_bl].shape[0] + length = self._nsx_data_header[nsx_nb][seg_index]["nb_data_points"] keep_seg = keep_seg and (length >= 2) - if not keep_seg: - removed_seg.append(data_bl) - for nsx_nb in self.nsx_to_load: - self.nsx_datas[nsx_nb].pop(data_bl) - self._nsx_data_header[nsx_nb].pop(data_bl) - - # Keys need to be increasing from 0 to maximum in steps of 1 - # To ensure this after removing empty segments, some keys need to be re mapped - for i in removed_seg[::-1]: - for j in range(i + 1, self._nb_segment): - # remap nsx seg index - for nsx_nb in self.nsx_to_load: - data = self.nsx_datas[nsx_nb].pop(j) - self.nsx_datas[nsx_nb][j - 1] = data + empty_indices.append(seg_index) - data_header = self._nsx_data_header[nsx_nb].pop(j) - self._nsx_data_header[nsx_nb][j - 1] = data_header + # Remove empty segments in reverse order to preserve indices + for seg_index in reversed(empty_indices): + for nsx_nb in self.nsx_to_load: + del self._nsx_data_header[nsx_nb][seg_index] - # Also remap nev data, ev_ids are the equivalent to keys above - if self._avail_files["nev"]: - for k, (data, ev_ids) in self.nev_data.items(): - ev_ids[ev_ids == j] -= 1 + # Remap nev segment ids: shift down all ids above the removed segment + if self._avail_files["nev"]: + for _key, (data, ev_ids) in self.nev_data.items(): + ev_ids[ev_ids > seg_index] -= 1 self._nb_segment -= 1