Source code for turtlewave_hdEEG.dataset

import numpy as np
import json
import scipy.io
import h5py
from pathlib import Path
from datetime import datetime
from wonambi import Dataset as WonambiDataset



class LargeDataset:
    """Dataset class optimized for large EEG recordings.

    Wraps a Wonambi ``Dataset``, optionally enriches its header with
    metadata read from an EEGLAB MAT file (group, condition, session, sleep
    stages, events, recording start date) and can mirror the signal into a
    ``numpy.memmap`` file so that later reads do not go through the original
    reader.
    """

    # Exact formats tried (in order) before falling back to dateutil.
    _DATE_FORMATS = (
        '%d-%b-%Y %H:%M:%S',  # 01-Jan-2020 12:00:00
        '%Y-%m-%d %H:%M:%S',  # 2020-01-01 12:00:00
        '%d.%m.%Y %H:%M:%S',  # 01.01.2020 12:00:00
        '%m/%d/%Y %H:%M:%S',  # 01/01/2020 12:00:00
        '%Y-%m-%dT%H:%M:%S',  # 2020-01-01T12:00:00
    )

    def __init__(self, filename, create_memmap=False, memmap_dir=None,
                 extract_eeglab_metadata=True):
        """
        Initialize a large dataset handler

        Parameters
        ----------
        filename : str
            Path to the original EEG file
        create_memmap : bool
            Whether to create a memory-mapped version of the data
        memmap_dir : str or None
            Directory to store memory-mapped files, if None use same
            directory as input
        extract_eeglab_metadata : bool
            Whether to try reading extra EEGLAB metadata (events, stages,
            recording start date) from the file
        """
        self.filename = Path(filename)
        self.original_dataset = WonambiDataset(filename)
        self.memmap_info = None

        # Copy basic header info
        self.header = self.original_dataset.header
        self.channels = self.header['chan_name']
        self.sampling_rate = self.header['s_freq']

        # Fix EEGLAB start time if needed
        if extract_eeglab_metadata:
            self._extract_eeglab_metadata()

        # Create memory map if requested
        if create_memmap:
            self.create_memmap(memmap_dir)

    def _extract_eeglab_metadata(self):
        """
        Extract metadata from an EEGLAB .mat file.

        The file is first read with ``scipy.io.loadmat``; when that raises
        ``NotImplementedError`` (MATLAB v7.3 files) it is read with h5py
        instead. Any failure is reported on stdout and never propagated, so
        a file without EEGLAB metadata does not prevent construction.

        Returns
        -------
        dict or None
            The loaded file content, or None if loading or processing failed
        """
        try:
            eeglab_data = scipy.io.loadmat(self.filename,
                                           struct_as_record=False,
                                           squeeze_me=True)
            is_h5py = False
        except NotImplementedError:
            # Handle MATLAB v7.3 files using h5py
            print("MATLAB v7.3 file detected. Using h5py to load the file.")
            try:
                with h5py.File(self.filename, 'r') as f:
                    eeglab_data = self._load_hdf5_data(f)
            except Exception as e:
                print(f"Error extracting EEGLAB metadata: {e}")
                return None
            is_h5py = True
        except Exception as e:
            # e.g. the recording is not a MAT file at all
            print(f"Error extracting EEGLAB metadata: {e}")
            return None

        try:
            if is_h5py:
                self._process_h5py_metadata(eeglab_data)
            else:
                self._process_scipy_metadata(eeglab_data)
        except Exception as e:
            print(f"Error extracting EEGLAB metadata: {e}")
            return None
        return eeglab_data

    def _process_scipy_metadata(self, eeglab_data):
        """Process metadata from scipy.io.loadmat structure

        Parameters
        ----------
        eeglab_data : dict
            Mapping returned by ``scipy.io.loadmat``; the metadata is read
            from its ``'EEG'`` entry
        """
        # Access the EEG structure
        eeg = eeglab_data.get('EEG', None)
        if eeg is None:
            print("Warning: Could not find EEG structure in the EEGLAB file")
            return

        # Extract additional metadata from the EEG structure
        try:
            for attr in ('group', 'condition', 'session'):
                if hasattr(eeg, attr):
                    self.header[attr] = getattr(eeg, attr)

            etc = getattr(eeg, 'etc', None)
            if etc is not None and hasattr(etc, 'stages'):
                self.header['stages'] = etc.stages

            if hasattr(eeg, 'event'):
                # squeeze_me=True collapses a single event to a bare struct,
                # so normalise to a 1-D sequence before iterating.
                events = np.atleast_1d(eeg.event).ravel()
                self.header['event'] = {
                    # onsets are in sample points
                    'onsets': [getattr(ev, 'latency', None) for ev in events],
                    'types': [getattr(ev, 'type', None) for ev in events],
                    'durations': [getattr(ev, 'duration', None)
                                  for ev in events],
                    'isreject': [getattr(ev, 'is_reject', None)
                                 for ev in events],
                }

            # Parse the date string into a datetime object
            if etc is not None and hasattr(etc, 'rec_startdate'):
                print(f"Found rec_startdate in EEG.etc: {etc.rec_startdate}")
                parsed_date = self._parse_start_date(etc.rec_startdate)
                # Make sure we update both places where the start time
                # might be stored
                if parsed_date is not None:
                    self.original_dataset.start_time = parsed_date
        except Exception as e:
            print(f"Error processing scipy metadata: {e}")

    def _process_h5py_metadata(self, eeglab_data):
        """Process metadata from h5py structure

        Parameters
        ----------
        eeglab_data : dict
            Nested dictionary produced by ``_load_hdf5_data``; the metadata
            is read from its ``'EEG'`` entry
        """
        # Access the EEG structure
        eeg = eeglab_data.get('EEG', None)
        if eeg is None:
            print("Warning: Could not find EEG structure in the EEGLAB file")
            return

        # Extract additional metadata from the EEG structure
        try:
            print("Processing EEG metadata...")
            for attr in ('group', 'condition', 'session'):
                if attr in eeg:
                    print(f"Extracting field: {attr}")
                    value = self._resolve_h5py_value(eeg[attr])
                    if value is not None:  # Skip None values
                        self.header[attr] = value

            etc = eeg['etc'] if 'etc' in eeg else None
            if not isinstance(etc, dict):
                etc = {}

            # Extract stages if available
            if 'stages' in etc:
                print("Extracting stages")
                stages = self._resolve_h5py_value(etc['stages'])
                if stages is not None:
                    self.header['stages'] = stages

            # Extract events if available
            if 'event' in eeg:
                print("Processing events...")
                events = eeg['event']
                field_to_key = {
                    'latency': 'onsets',
                    'type': 'types',
                    'duration': 'durations',
                    'is_reject': 'isreject',
                }
                annotations = {key: [] for key in field_to_key.values()}
                for field, key in field_to_key.items():
                    if field not in events:
                        continue
                    value = self._resolve_h5py_value(events[field])
                    if value is None:
                        continue
                    # A single event resolves to a scalar; keep every entry
                    # a list so the structure matches the scipy path.
                    annotations[key] = (value if isinstance(value, list)
                                        else [value])
                if any(len(v) > 0 for v in annotations.values()):
                    self.header['event'] = annotations

            # Parse the date string into a datetime object
            if 'rec_startdate' in etc:
                print("Extracting rec_startdate")
                rec_startdate = self._resolve_h5py_value(etc['rec_startdate'])
                if isinstance(rec_startdate, list) and len(rec_startdate) == 1:
                    rec_startdate = rec_startdate[0]
                if rec_startdate:
                    parsed_date = self._parse_start_date(rec_startdate)
                    # Keep the wrapped dataset in sync, as the scipy path does
                    if parsed_date is not None:
                        self.original_dataset.start_time = parsed_date
        except Exception as e:
            print(f"Error processing h5py metadata: {e}")

    def _parse_start_date(self, rec_startdate):
        """Parse the recording start date with multiple format attempts

        Accepts a string, bytes, or a NumPy array holding either strings or
        character codes (zeros are dropped as padding). The explicit formats
        in ``_DATE_FORMATS`` are tried first because they are unambiguous;
        ``dateutil`` (if installed) is only used as a fallback, e.g. for ISO
        strings with fractional seconds or a UTC offset.

        On success ``self.header['start_time']`` is updated.

        Returns
        -------
        datetime.datetime or None
            The parsed date, or None if it could not be parsed
        """
        try:
            print(f"Original rec_startdate type: {type(rec_startdate)}")

            if isinstance(rec_startdate, np.ndarray):
                print(f"Array shape: {rec_startdate.shape}, "
                      f"dtype: {rec_startdate.dtype}")
                kind = rec_startdate.dtype.kind
                flat = rec_startdate.ravel()
                if kind in ('i', 'u'):
                    # Character codes (e.g. MATLAB uint16 char arrays)
                    rec_startdate = ''.join(
                        chr(int(code)) for code in flat if code != 0)
                    print("Converted character array to string: "
                          f"{rec_startdate}")
                elif kind == 'S':
                    rec_startdate = b''.join(flat.tolist()).decode('utf-8')
                elif kind == 'U':
                    rec_startdate = ''.join(flat.tolist())

            # Handle bytes
            if isinstance(rec_startdate, bytes):
                rec_startdate = rec_startdate.decode('utf-8')

            # Make sure we have a string at this point
            if not isinstance(rec_startdate, str):
                print("Warning: Could not parse date, unexpected type after "
                      f"conversion: {type(rec_startdate)}")
                return None

            rec_startdate = rec_startdate.strip()
            print(f"Converted date string: {rec_startdate}")
            if not rec_startdate:
                print(f"Warning: Could not parse date format: {rec_startdate}")
                return None

            parsed_date = None
            for fmt in self._DATE_FORMATS:
                try:
                    parsed_date = datetime.strptime(rec_startdate, fmt)
                    break
                except ValueError:
                    continue

            if parsed_date is None:
                # For strings like '2019-06-17T19:19:55.256234+10:00'
                try:
                    from dateutil import parser as date_parser
                    parsed_date = date_parser.parse(rec_startdate)
                except (ImportError, ValueError, OverflowError):
                    parsed_date = None

            if parsed_date is None:
                print(f"Warning: Could not parse date format: {rec_startdate}")
                return None

            # Update the header with the correct start time
            self.header['start_time'] = parsed_date
            print(f"Updated header start_time to: {parsed_date}")
            return parsed_date

        except Exception as e:
            print(f"Error parsing start date: {e} "
                  f"(type: {type(rec_startdate)})")
            # Keep the raw values if we couldn't process the array
            if isinstance(rec_startdate, np.ndarray):
                try:
                    self.header['date_array'] = rec_startdate.tolist()
                except Exception as store_error:
                    print(f"Could not store raw date array: {store_error}")
            return None

    def _load_hdf5_data(self, hdf5_group, depth=0, max_depth=10, path="root"):
        """
        Recursively load data from an HDF5 group into a nested dictionary.

        Datasets larger than 100 MB and reference arrays with more than
        10000 entries are not loaded; a small descriptive dict is stored in
        their place.

        Parameters
        ----------
        hdf5_group : h5py.Group
            The HDF5 group to load
        depth : int
            Current recursion depth
        max_depth : int
            Maximum recursion depth to prevent infinite recursion
        path : str
            Current path in the HDF5 file (for debugging)

        Returns
        -------
        dict
            Nested dictionary containing the HDF5 data
        """
        # Guard against excessive recursion
        if depth >= max_depth:
            print(f"Maximum recursion depth reached at {path}, "
                  "stopping recursion")
            return {"max_depth_reached": True}

        try:
            keys = list(hdf5_group.keys())
        except Exception as e:
            print(f"Error processing HDF5 group at {path}: {e}")
            return {"error": str(e)}

        result = {}
        for key in keys:
            new_path = f"{path}/{key}"
            try:
                item = hdf5_group[key]
                if isinstance(item, h5py.Group):
                    result[key] = self._load_hdf5_data(
                        item, depth + 1, max_depth, new_path)
                elif isinstance(item, h5py.Dataset):
                    size_mb = float(np.prod(item.shape)
                                    * item.dtype.itemsize / (1024 * 1024))
                    if size_mb > 100:
                        print(f"Skipping large dataset {new_path}: "
                              f"{size_mb:.2f} MB")
                        result[key] = {
                            "shape": item.shape,
                            "dtype": str(item.dtype),
                            "size_mb": size_mb,
                            "large_dataset": True,
                        }
                    elif item.dtype == h5py.ref_dtype and item.size > 10000:
                        print("Large reference array detected at "
                              f"{new_path}: {item.size} references")
                        result[key] = {
                            "shape": item.shape,
                            "dtype": str(item.dtype),
                            "reference_count": item.size,
                            "large_reference_array": True,
                        }
                    else:
                        # Load normal datasets
                        result[key] = item[()]
            except Exception as e:
                # One unreadable item must not abort the whole group
                print(f"Error loading {path}/{key}: {e}")
                result[key] = {"error": str(e)}
        return result

    def _resolve_h5py_value(self, value, _depth=0, _max_depth=10):
        """
        Resolve HDF5 references in a value, which may be a single reference,
        an array of references, or a list.

        Parameters
        ----------
        value : any
            The value to resolve, which might contain HDF5 references

        Returns
        -------
        resolved_value : any
            The resolved value. All-zero uint64 arrays (null references)
            give None, uint16 arrays are decoded to a string, one-element
            arrays become a scalar and other arrays become lists. On error
            the input is returned unchanged.
        """
        # Circuit breaker to prevent infinite recursion
        if _depth > _max_depth:
            print("WARNING: Maximum recursion depth reached "
                  f"({_depth}/{_max_depth}), stopping resolution")
            return None

        try:
            is_array = isinstance(value, np.ndarray)

            # Detect null reference arrays
            if is_array and value.dtype == np.uint64 and np.all(value == 0):
                return None

            # Handle direct reference
            if isinstance(value, h5py.Reference):
                return self._resolve_single_reference(value)

            # Handle array of references; anything that is not a valid
            # reference is reported as None
            if is_array and value.dtype == object:
                return [
                    self._resolve_single_reference(ref)
                    if isinstance(ref, h5py.Reference) and ref else None
                    for ref in value.flatten()
                ]

            # Handle lists with circuit breaker
            if isinstance(value, list):
                return [self._resolve_h5py_value(item, _depth + 1, _max_depth)
                        for item in value]

            # uint16 arrays hold text (date strings, filenames, subjects)
            if is_array and value.dtype == np.uint16:
                try:
                    return ''.join(chr(c) for c in value.flatten() if c != 0)
                except (TypeError, ValueError, OverflowError):
                    return value.tolist()

            # Simplify (N, 1) or (1, N) arrays
            if is_array and value.size == 1:
                return value.item()

            # Return other numpy arrays as lists
            if is_array:
                return value.tolist()

            return value

        except Exception as e:
            print(f"Error resolving HDF5 value: {e}")
            return value

    def _resolve_single_reference(self, reference):
        """
        Resolve a single HDF5 object reference.

        The file is reopened for every call because the handle used while
        loading is already closed at this point.

        Parameters
        ----------
        reference : h5py.Reference
            The HDF5 object reference to resolve

        Returns
        -------
        data
            The resolved data (strings decoded, arrays converted to Python
            scalars or lists), or None if the reference cannot be resolved
        """
        try:
            with h5py.File(self.filename, 'r') as f:
                data = f[reference][()]

            if not isinstance(data, np.ndarray):
                return data

            kind = data.dtype.kind
            if kind == 'S':
                # Byte strings: single value or array of values
                if data.size == 1:
                    return data.item().decode('utf-8')
                return [s.decode('utf-8') if isinstance(s, bytes) else s
                        for s in data]
            if kind == 'U':
                return data.tolist()
            if data.dtype == np.uint8:
                if data.size == 1:
                    return int(data.item())  # Return 0 or 1
                try:
                    return data.tobytes().decode('utf-8').rstrip('\x00')
                except UnicodeDecodeError:
                    return [int(x) for x in data.flatten()]
            if data.dtype == np.uint16:
                # Convert uint16 array to string (character codes)
                try:
                    return ''.join(chr(c) for c in data.flatten() if c != 0)
                except (TypeError, ValueError, OverflowError):
                    return data.tolist()
            if kind == 'f' and data.size == 1:
                return data.item()  # Return as a Python float

            # Convert numpy arrays to lists for better serialization
            return data.tolist()

        except Exception as e:
            print(f"Error resolving single HDF5 reference: {e}")
            return None

    def create_memmap(self, memmap_dir=None):
        """Create a memory-mapped version of the data for faster access

        The signal is copied minute by minute into
        ``<stem>_memmap.dat`` (float32, channels x samples) and a
        ``<stem>_memmap.json`` file describing it is written next to it. A
        chunk that cannot be read is reported and left as zeros.

        Parameters
        ----------
        memmap_dir : str or None
            Directory for the two files; None means the directory of the
            original file. It is created if missing.

        Returns
        -------
        dict
            The memory-map description also stored in ``self.memmap_info``
        """
        if memmap_dir is None:
            memmap_dir = self.filename.parent
        else:
            memmap_dir = Path(memmap_dir)
        memmap_dir.mkdir(parents=True, exist_ok=True)

        memmap_path = memmap_dir / f"{self.filename.stem}_memmap.dat"
        info_path = memmap_dir / f"{self.filename.stem}_memmap.json"

        # Get dataset dimensions
        s_freq = self.sampling_rate
        n_channels = len(self.channels)
        n_samples = self.header.get('n_samples')
        if n_samples is None:
            # No sample count in the header: fall back to the recording
            # duration, defaulting to 8 hours
            n_samples = self.header.get('recording_duration', 8 * 3600) * s_freq
        n_samples = int(n_samples)

        # Create memory-mapped file
        shape = (n_channels, n_samples)
        mmap = np.memmap(memmap_path, dtype=np.float32, mode='w+', shape=shape)

        # 1 minute of data per chunk; must be an int because the sampling
        # rate is frequently a float and is used for range() and slicing
        chunk_size = max(1, int(round(60 * s_freq)))
        chunks = (n_samples + chunk_size - 1) // chunk_size

        print(f"Creating memory map with {chunks} chunks...")
        for i in range(chunks):
            start_sample = i * chunk_size
            end_sample = min(start_sample + chunk_size, n_samples)

            start_time = start_sample / s_freq
            end_time = end_sample / s_freq

            print(f"Processing chunk {i+1}/{chunks}: "
                  f"{start_time:.1f}s - {end_time:.1f}s")

            # Read chunk from original data
            try:
                data = self.original_dataset.read_data(begtime=start_time,
                                                       endtime=end_time)
                mmap[:, start_sample:end_sample] = data.data[0]
            except Exception as e:
                print(f"Error processing chunk {i+1}: {e}")

        # Flush to disk and release the writable mapping
        mmap.flush()
        del mmap

        # Plain Python values only, so the description is JSON-serialisable
        self.memmap_info = {
            'filepath': str(memmap_path),
            'shape': shape,
            'channels': [str(name) for name in self.channels],
            'sampling_rate': (s_freq.item() if isinstance(s_freq, np.generic)
                              else s_freq),
            'dtype': 'float32',
        }

        # Save for later use
        with open(info_path, 'w') as f:
            json.dump(self.memmap_info, f)

        print(f"Memory map created at {memmap_path}")
        return self.memmap_info

    def read_data(self, begtime=None, endtime=None, chan=None):
        """
        Read data from the dataset, using memory map if available

        Parameters
        ----------
        begtime : float or None
            Start time in seconds
        endtime : float or None
            End time in seconds
        chan : list or None
            List of channels to load

        Returns
        -------
        data
            The requested data: the ``ChanTime`` built by
            ``_read_from_memmap`` when a memory map exists, otherwise
            whatever the wrapped Wonambi dataset's ``read_data`` returns
        """
        if self.memmap_info is not None:
            # Use memory map for faster access
            return self._read_from_memmap(begtime, endtime, chan)
        # Fall back to original Wonambi method
        return self.original_dataset.read_data(begtime=begtime,
                                               endtime=endtime, chan=chan)

    def _read_from_memmap(self, begtime=None, endtime=None, chan=None):
        """Read data from memory map

        Parameters
        ----------
        begtime : float or None
            Start time in seconds; None means the first sample
        endtime : float or None
            End time in seconds; None means the last sample
        chan : list or None
            Channel names or channel indices; None means all channels.
            Unknown channel names are ignored.

        Returns
        -------
        wonambi.datatype.ChanTime
            Object whose ``data`` holds a copy of the requested slice
        """
        mmap_path = self.memmap_info['filepath']
        shape = tuple(self.memmap_info['shape'])
        n_total = shape[1]

        # Open memory map
        mmap_data = np.memmap(mmap_path, dtype=np.float32, mode='r',
                              shape=shape)

        # Calculate indices, clamped to the recording so the time axis
        # always has the same length as the returned data
        start_idx = 0 if begtime is None else int(begtime * self.sampling_rate)
        end_idx = n_total if endtime is None else int(endtime * self.sampling_rate)
        start_idx = min(max(start_idx, 0), n_total)
        end_idx = min(max(end_idx, start_idx), n_total)

        # Get channel indices
        if chan is None:
            chan_indices = slice(None)  # All channels
        elif len(chan) > 0 and isinstance(chan[0], str):
            # Convert channel names to indices
            chan_indices = [self.channels.index(ch) for ch in chan
                            if ch in self.channels]
        else:
            chan_indices = list(chan)

        # Copy so the result does not depend on the memmap staying open
        data_copy = mmap_data[chan_indices, start_idx:end_idx].copy()

        # Format similar to Wonambi's output
        # NOTE(review): the axes are assigned as plain sequences here, not
        # wrapped per trial like ``data`` - confirm that callers expect this.
        from wonambi.datatype import ChanTime
        output = ChanTime()
        output.data = np.array([data_copy])
        if isinstance(chan_indices, list):
            output.axis['chan'] = [self.channels[i] for i in chan_indices]
        else:
            output.axis['chan'] = self.channels
        output.axis['time'] = np.arange(start_idx, end_idx) / self.sampling_rate
        output.s_freq = self.sampling_rate

        return output