Source code for turtlewave_hdEEG.extensions

"""
Custom extensions to Wonambi spindle and slow-wave detection
"""

from numpy import mean, arange
from wonambi.detect import DetectSpindle as OriginalDetectSpindle
from wonambi.detect import DetectSlowWave as OriginalDetectSlowWave



[docs] class ImprovedDetectSpindle(OriginalDetectSpindle): def __init__(self, method='Moelle2011', frequency=None, duration=None, det_thresh=None, sel_thresh=None, moving_rms=None, smooth_dur=None, tolerance=None, min_interval=None, merge=False, polar='normal', **kwargs): """ Initialize improved spindle detection. Parameters ---------- method : str Detection method. Supported methods include: 'Ferrarelli2007', 'Moelle2011', 'Nir2011', 'Wamsley2012', 'Martin2013', 'Ray2015', 'Lacourse2018' frequency : tuple of float Frequency range for spindle detection (low and high) duration : tuple of float Duration range for spindles in seconds (min and max) det_thresh : float or None Detection threshold (method-specific units) sel_thresh : float or None Selection threshold (method-specific units) moving_rms : dict or float or None Parameters for moving RMS, format: {'dur': float, 'step': float or None} or just duration as float smooth_dur : float or None Duration for smoothing window in seconds tolerance : float or None Tolerance for merging events in seconds min_interval : float or None Minimum interval between events in seconds merge : bool If True, merge events across channels polar : str Signal polarity - 'normal' or 'opposite' **kwargs : dict Additional method-specific parameters """ # Call parent constructor super().__init__(method, frequency, duration, merge) # Store signal inversion if polar == 'normal': self.invert = False elif polar == 'opposite': self.invert = True # Store parameters that will be applied after default initialization self._custom_params = { 'det_thresh': det_thresh, 'sel_thresh': sel_thresh, 'moving_rms_dur': moving_rms, 'smooth_dur': smooth_dur, 'tolerance': tolerance, 'min_interval': min_interval, **kwargs # Include any other custom parameters } # Set method-specific parameters self._set_method_params() # Apply custom parameters to override defaults self._apply_custom_parameters() def _set_method_params(self): """Set parameters specific to each 
detection method.""" if self.method == 'Ferrarelli2007': if not hasattr(self, 'frequency') or self.frequency is None: self.frequency = (11, 15) if not hasattr(self, 'duration') or self.duration is None: self.duration = (0.3, 3) self.det_remez = {'freq': self.frequency, 'rolloff': 0.9, 'dur': 2.56, 'step': None } self.det_thresh = 8 self.sel_thresh = 2 elif self.method == 'Moelle2011': if not hasattr(self, 'frequency') or self.frequency is None: self.frequency = (12, 15) if not hasattr(self, 'duration') or self.duration is None: self.duration = (0.5, 3) self.det_remez = {'freq': self.frequency, 'rolloff': 1.7, 'dur': 2.36, 'step': None } self.moving_rms = {'dur': .2, 'step': None} self.smooth = {'dur': .2, 'win': 'flat'} self.det_thresh = 1.5 elif self.method == 'Nir2011': if not hasattr(self, 'frequency') or self.frequency is None: self.frequency = (9.2, 16.8) if not hasattr(self, 'duration') or self.duration is None: self.duration = (0.5, 2) self.det_butter = {'order': 2, 'freq': self.frequency, 'step': None } self.tolerance = 1 self.smooth = {'dur': .04} # is in fact sigma self.det_thresh = 3 self.sel_thresh = 1 elif self.method == 'Wamsley2012': if not hasattr(self, 'frequency') or self.frequency is None: self.frequency = (12, 15) if not hasattr(self, 'duration') or self.duration is None: self.duration = (0.3, 3) self.det_wavelet = {'f0': mean(self.frequency), 'sd': .8, 'dur': 1., 'output': 'complex', 'step': None } self.smooth = {'dur': .1, 'win': 'flat'} self.det_thresh = 4.5 elif self.method == 'Martin2013': if not hasattr(self, 'frequency') or self.frequency is None: self.frequency = (11.5, 14.5) if not hasattr(self, 'duration') or self.duration is None: self.duration = (.5, 3) self.det_remez = {'freq': self.frequency, 'rolloff': 1.1, 'dur': 2.56, 'step': None } self.moving_rms = {'dur': .25, 'step': .25} self.det_thresh = 95 elif self.method == 'Ray2015': if not hasattr(self, 'frequency') or self.frequency is None: self.frequency = (11, 16) if not 
hasattr(self, 'duration') or self.duration is None: self.duration = (.49, None) self.cdemod = {'freq': mean(self.frequency)} self.det_butter = {'freq': (0.3, 35), 'order': 4, 'step': None} self.det_low_butter = {'freq': 5, 'order': 4, 'step': None} self.min_interval = 0.25 # they only start looking again after .25s self.smooth = {'dur': 2 / self.cdemod['freq'], 'win': 'triangle'} self.zscore = {'dur': 60, 'step': None, 'pcl_range': None} self.det_thresh = 2.33 self.sel_thresh = 0.1 elif self.method == 'Lacourse2018': if not hasattr(self, 'frequency') or self.frequency is None: self.frequency = (11, 16) if not hasattr(self, 'duration') or self.duration is None: self.duration = (.3, 2.5) self.det_butter = {'freq': self.frequency, 'order': 20, 'step': None} self.det_butter2 = {'freq': (.3, 30), 'order': 5, 'step': None} self.windowing = {'dur': .3, 'step': .1} win = self.windowing self.moving_ms = {'dur': win['dur'], 'step': win['step']} self.moving_power_ratio = {'dur': win['dur'], 'step': win['step'], 'freq_narrow': self.frequency, 'freq_broad': (4.5, 30), 'fft_dur': 2} self.zscore = {'dur': 30, 'step': None, 'pcl_range': (10, 90)} self.moving_covar = {'dur': win['dur'], 'step': win['step']} self.moving_sd = {'dur': win['dur'], 'step': win['step']} self.smooth = {'dur': 0.3, 'win': 'flat_left'} self.abs_pow_thresh = 1.25 self.rel_pow_thresh = 1.6 self.covar_thresh = 1.3 self.corr_thresh = 0.69 else: raise ValueError(f'Unknown method: {self.method}') # Safety checks for all methods - include step parameter checks here for param_name in ['moving_rms', 'moving_ms', 'moving_power_ratio', 'moving_covar', 'moving_sd', 'windowing', 'zscore', 'det_butter', 'det_remez', 'det_wavelet']: if hasattr(self, param_name) and isinstance(getattr(self, param_name), dict): param_dict = getattr(self, param_name) if 'step' not in param_dict: param_dict['step'] = None def _ensure_step_parameters(self): """ Ensure all required parameters exist in method dictionaries with comprehensive 
check. """ # Get all attributes of self that are dictionaries for attr_name in dir(self): # Skip private attributes and non-data attributes if attr_name.startswith('_') or callable(getattr(self, attr_name)): continue attr = getattr(self, attr_name) # Check if it's a dictionary if isinstance(attr, dict): # If it's a nested dictionary that contains parameters if any(k in attr for k in ['dur', 'freq', 'order']): if 'step' not in attr: attr['step'] = None # Ensure pcl_range exists for zscore dictionaries if attr_name == 'zscore' or (isinstance(attr, dict) and 'dur' in attr and 'pcl_range' not in attr): attr['pcl_range'] = None # Handle other common missing parameters if 'freq' in attr and isinstance(attr['freq'], tuple) and 'rolloff' not in attr and attr_name.startswith('det_'): attr['rolloff'] = 0.5 # Handle moving_power_ratio parameters if attr_name == 'moving_power_ratio' or (isinstance(attr, dict) and 'dur' in attr and ('freq_narrow' not in attr or 'freq_broad' not in attr)): # Add default parameters for moving_power_ratio if 'freq_narrow' not in attr: attr['freq_narrow'] = self.frequency if hasattr(self, 'frequency') else (11, 16) if 'freq_broad' not in attr: attr['freq_broad'] = (4.5, 30) if 'fft_dur' not in attr: attr['fft_dur'] = 2 # handle dictionaries in list attributes elif isinstance(attr, list): for item in attr: if isinstance(item, dict): if any(k in item for k in ['dur', 'freq', 'order']): if 'step' not in item: item['step'] = None if 'dur' in item and 'pcl_range' not in item: item['pcl_range'] = None if 'sd' in item and 'output' not in item: item['output'] = 'complex' # Specific method checks if self.method == 'Ray2015' and hasattr(self, 'zscore'): if 'pcl_range' not in self.zscore: self.zscore['pcl_range'] = None if self.method == 'Wamsley2012' and hasattr(self, 'det_wavelet'): if 'f0' not in self.det_wavelet: self.det_wavelet['f0'] = mean(self.frequency) if 'output' not in self.det_wavelet: self.det_wavelet['output'] = 'complex' # 
Lacourse2018-specific checks if self.method == 'Lacourse2018' and hasattr(self, 'moving_power_ratio'): # Ensure all required parameters exist if 'freq_narrow' not in self.moving_power_ratio: self.moving_power_ratio['freq_narrow'] = self.frequency if 'freq_broad' not in self.moving_power_ratio: self.moving_power_ratio['freq_broad'] = (4.5, 30) if 'fft_dur' not in self.moving_power_ratio: self.moving_power_ratio['fft_dur'] = 2 def _apply_custom_parameters(self): """Apply custom parameters, overriding defaults""" # Simple parameter overrides if self._custom_params['det_thresh'] is not None: self.det_thresh = self._custom_params['det_thresh'] if self._custom_params['sel_thresh'] is not None and hasattr(self, 'sel_thresh'): self.sel_thresh = self._custom_params['sel_thresh'] if self._custom_params['tolerance'] is not None: self.tolerance = self._custom_params['tolerance'] if self._custom_params['min_interval'] is not None: self.min_interval = self._custom_params['min_interval'] # Update moving RMS duration if provided if self._custom_params['moving_rms_dur'] is not None and hasattr(self, 'moving_rms'): # Handle both dictionary and float inputs for moving_rms if isinstance(self._custom_params['moving_rms_dur'], dict): if 'dur' in self._custom_params['moving_rms_dur']: self.moving_rms['dur'] = self._custom_params['moving_rms_dur']['dur'] if 'step' in self._custom_params['moving_rms_dur']: self.moving_rms['step'] = self._custom_params['moving_rms_dur']['step'] else: # If just a float is provided, assume it's the duration self.moving_rms['dur'] = self._custom_params['moving_rms_dur'] # Update smooth duration if provided if self._custom_params['smooth_dur'] is not None and hasattr(self, 'smooth'): self.smooth['dur'] = self._custom_params['smooth_dur'] # Method-specific parameters if self.method == 'Lacourse2018': if 'abs_pow_thresh' in self._custom_params: self.abs_pow_thresh = self._custom_params['abs_pow_thresh'] if 'rel_pow_thresh' in self._custom_params: 
self.rel_pow_thresh = self._custom_params['rel_pow_thresh'] if 'covar_thresh' in self._custom_params: self.covar_thresh = self._custom_params['covar_thresh'] if 'corr_thresh' in self._custom_params: self.corr_thresh = self._custom_params['corr_thresh'] if 'window_dur' in self._custom_params and self._custom_params['window_dur'] is not None: # Update all window durations win_dur = self._custom_params['window_dur'] for attr_name in ['windowing', 'moving_ms', 'moving_power_ratio', 'moving_covar', 'moving_sd']: if hasattr(self, attr_name): attr = getattr(self, attr_name) if isinstance(attr, dict): # Set step equal to dur/2 if not specified (common default) if 'step' not in attr or attr['step'] is None: if 'dur' in attr: attr['step'] = attr['dur'] / 2 elif self.method == 'Ray2015': if 'zscore_dur' in self._custom_params and self._custom_params['zscore_dur'] is not None: if hasattr(self, 'zscore'): self.zscore['dur'] = self._custom_params['zscore_dur'] # Always ensure step is present if 'step' not in self.zscore: self.zscore['step'] = None elif self.method == 'Wamsley2012': if 'wavelet_sd' in self._custom_params and self._custom_params['wavelet_sd'] is not None: if hasattr(self, 'det_wavelet'): self.det_wavelet['sd'] = self._custom_params['wavelet_sd'] if 'wavelet_dur' in self._custom_params and self._custom_params['wavelet_dur'] is not None: if hasattr(self, 'det_wavelet'): self.det_wavelet['dur'] = self._custom_params['wavelet_dur'] # Always ensure f0 is present for Wamsley2012 if hasattr(self, 'det_wavelet'): self.det_wavelet['f0'] = mean(self.frequency) # Always ensure step is present if 'step' not in self.det_wavelet: self.det_wavelet['step'] = None # Apply any additional custom parameters for key, value in self._custom_params.items(): if hasattr(self, key) and value is not None: setattr(self, key, value) self._ensure_step_parameters()
[docs] def __call__(self, data, parent=None): # 5 minutes timeout """ Detect spindles in the data with optional signal inversion. Parameters ---------- data : instance of Data The data to analyze parent : QWidget For use with GUI, as parent widget for the progress bar timeout : int Maximum time in seconds to allow for detection before timing out Returns ------- instance of graphoelement.Spindles Detected spindles """ # Add comprehensive check for step parameters right before detection self._ensure_step_parameters() # Check if we need to invert the signal if hasattr(self, 'invert') and self.invert: # Make a copy to avoid modifying the original data_copy = data.copy() # Invert signal for all epochs for i in range(len(data_copy.data)): data_copy.data[i] = -data_copy.data[i] return super().__call__(data_copy, parent) else: # No inversion needed, call parent method directly return super().__call__(data, parent)
[docs] class ImprovedDetectSlowWave(OriginalDetectSlowWave): def __init__(self, method='Massimini2004', frequency=None, duration=None, neg_peak_thresh=40, p2p_thresh=75, min_dur=None, max_dur=None, polar='normal'): """ Initialize improved slow wave detection. Parameters ---------- method : str Detection method. Supported methods: - 'Massimini2004': Traditional threshold-based detection - 'AASM/Massimini2004': AASM criteria with Massimini method - 'Ngo2015': Detection based on Ngo et al. 2015 - 'Staresina2015': Detection based on Staresina et al. 2015 frequency : tuple of float Frequency range for slow wave detection duration : tuple of float Duration range for slow waves in seconds (used for trough_duration in Massimini methods) neg_peak_thresh : float Minimum negative peak amplitude in μV p2p_thresh : float Minimum peak-to-peak amplitude in μV min_dur : float or None Minimum duration of a slow wave in seconds (used for Ngo2015 and Staresina2015) max_dur : float or None Maximum duration of a slow wave in seconds (used for Ngo2015 and Staresina2015) polar : str Signal polarity - 'normal' or 'opposite' """ super().__init__(method, duration) # Store additional parameters self.min_neg_amp = neg_peak_thresh self.min_ptp_amp = p2p_thresh if polar == 'normal': self.invert = False elif polar == 'opposite': self.invert = True # Store duration parameters self.min_dur_param = min_dur self.max_dur_param = max_dur # Override frequency if provided if frequency is not None: if method in ['Massimini2004', 'AASM/Massimini2004']: self.det_filt['freq'] = frequency elif method in ['Ngo2015', 'Staresina2015']: self.lowpass['freq'] = frequency[1] # Use upper bound self.det_filt['freq'] = frequency # Set method-specific parameters self._set_method_params() def _set_method_params(self): """Set parameters specific to each detection method.""" if self.method == 'Massimini2004': if not hasattr(self, 'det_filt'): self.det_filt = { 'order': 2, 'freq': (0.1, 4.0) } # Use default values unless 
overridden self.trough_duration = (0.3, 1.0) self.max_trough_amp = -80 self.min_ptp = 140 self.min_dur = 0 self.max_dur = None elif self.method == 'AASM/Massimini2004': if not hasattr(self, 'det_filt'): self.det_filt = { 'order': 2, 'freq': (0.1, 1.0) } # Use default values unless overridden self.trough_duration = (0.25, 1.0) self.max_trough_amp = -37 self.min_ptp = 70 self.min_dur = 0 self.max_dur = None elif self.method == 'Ngo2015': if not hasattr(self, 'lowpass'): self.lowpass = { 'order': 2, 'freq': 3.5 } # Use provided min_dur and max_dur if available, otherwise use defaults self.min_dur = 0.833 if self.min_dur_param is None else self.min_dur_param self.max_dur = 2.0 if self.max_dur_param is None else self.max_dur_param if not hasattr(self, 'det_filt'): self.det_filt = { 'freq': (1 / self.max_dur, 1 / self.min_dur) } self.peak_thresh = 1.25 self.ptp_thresh = 1.25 elif self.method == 'Staresina2015': if not hasattr(self, 'lowpass'): self.lowpass = { 'order': 3, 'freq': 1.25 } # Use provided min_dur and max_dur if available, otherwise use defaults self.min_dur = 0.8 if self.min_dur_param is None else self.min_dur_param self.max_dur = 2.0 if self.max_dur_param is None else self.max_dur_param if not hasattr(self, 'det_filt'): self.det_filt = { 'freq': (1 / self.max_dur, 1 / self.min_dur) } self.ptp_thresh = 75 else: raise ValueError('Method must be one of: Massimini2004, AASM/Massimini2004, Ngo2015, or Staresina2015') # Always update filter frequency based on min_dur and max_dur for these methods if self.method in ['Ngo2015', 'Staresina2015'] and self.min_dur > 0 and self.max_dur > 0: self.det_filt['freq'] = (1 / self.max_dur, 1 / self.min_dur)
[docs] def __call__(self, data): """ Detect slow waves in the data. Parameters ---------- data : instance of Data The data to analyze Returns ------- instance of graphoelement.SlowWaves Detected slow waves """ # Invert signal if requested if self.invert: data.data[0][0] = -data.data[0][0] # Run detection using parent class events = super().__call__(data) # Apply additional amplitude criteria if needed filtered_events = [] for evt in events: if (abs(evt['trough_val']) >= self.min_neg_amp and abs(evt['ptp']) >= self.min_ptp_amp): filtered_events.append(evt) # Update events events.events = filtered_events return events