Source code for turtlewave_hdEEG.swprocessor

import numpy as np
import time
import os
#import multiprocessing
import csv
from wonambi.trans import select, fetch, math
from wonambi.attr import Annotations
from turtlewave_hdEEG.extensions import ImprovedDetectSlowWave as DetectSlowWave
import json
import datetime
import logging


[docs] class ParalSWA: """ A class for parallel detection and analysis of slow wave activity (SWA) across multiple channels. """ def __init__(self, dataset, annotations=None, log_level=logging.INFO, log_file=None): """ Initialize the ParalSWA object. Parameters ---------- dataset : Dataset Dataset object containing EEG data annotations : XLAnnotations Annotations object for storing and retrieving events log_level : int Logging level (e.g., logging.DEBUG, logging.INFO) log_file : str or None Path to log file. If None, logs to console only. """ self.dataset = dataset self.annotations = annotations # Setup logging self.logger = self._setup_logger(log_level, log_file) def _setup_logger(self, log_level, log_file=None): """ Set up a logger for the SWAProcessor. Parameters ---------- log_level : int Logging level (e.g., logging.DEBUG, logging.INFO) log_file : str or None Path to log file. If None, logs to console only. Returns ------- logger : logging.Logger Configured logger instance """ # Create a logger logger = logging.getLogger('turtlewave_hdEEG.swaprocessor') logger.setLevel(log_level) # Create formatter formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') # Create console handler console_handler = logging.StreamHandler() console_handler.setFormatter(formatter) logger.addHandler(console_handler) # Create file handler if log_file specified if log_file: file_handler = logging.FileHandler(log_file) file_handler.setFormatter(formatter) logger.addHandler(file_handler) return logger
[docs] def clean_memory(self): """ Perform thorough memory cleanup to release resources """ import gc import sys # Clear any large variables in the class if hasattr(self, '_temp_data'): del self._temp_data # Force garbage collection gc.collect() # For more aggressive cleanup on systems that support it if sys.platform == 'linux': try: import resource import psutil # Suggest to OS to release memory psutil.Process().memory_info() resource.RUSAGE_SELF except ImportError: self.logger.info("psutil not available for advanced memory cleanup") self.logger.info("Memory cleanup performed")
[docs] def detect_slow_waves(self, method='Massimini2004', chan=None, ref_chan=[], grp_name='eeg', frequency=(0.1, 4), trough_duration=(0.3, 1.5), neg_peak_thresh=-80.0, p2p_thresh=140.0, min_dur=None, max_dur=None, detrend=False, polar='normal', # normal vs opposite reject_artifacts=True, reject_arousals=True, stage=None, cat=None, peak_thresh_sigma=None, ptp_thresh_sigma=None, save_to_annotations=False, json_dir=None, create_empty_json=True): """ Detect slow waves in the dataset while considering artifacts and arousals. Parameters ---------- method : str or list Detection method(s) to use ('Massimini2004', 'AASM/Massimini2004', 'Ngo2015', 'Staresina2015') chan : list or str Channels to analyze ref_chan : list or str Reference channel(s) for re-referencing grp_name : str Group name for channel selection frequency : tuple Frequency range for slow wave detection (min, max) trough_duration : tuple Duration range for slow wave trough in seconds (min, max) neg_peak_thresh : float Minimum negative peak threshold in μV p2p_thresh : float Minimum peak-to-peak amplitude threshold in μV peak_thresh_sigma : float or None Peak threshold in standard deviations (for Ngo2015 method) ptp_thresh_sigma : float or None Peak-to-peak threshold in standard deviations (for Ngo2015 method) invert : bool Whether to invert the signal polarity reject_artifacts : bool Whether to exclude segments marked with artifact annotations reject_arousals : bool Whether to exclude segments marked with arousal annotations stage : list or str Sleep stage(s) to analyze cat : tuple Category specification for data selection save_to_annotations : bool Whether to save detected slow waves to annotations json_dir : str or None Directory to save individual channel JSON files Returns ------- list List of all detected slow waves """ import uuid self.logger.info(r""" ___ __,__,__,__, /_@ \ / / \ \ \ \__\/-<_>-<_>-<->-|-< /\____________/~ / /===/ /=====\ \ "" "" "" '''' searching for slow waves... ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ """) # Validate polar parameter if polar not in ['normal', 'opposite']: self.logger.warning(f"Invalid polar value '{polar}'. Using 'normal'.") polar = 'normal' # Configure what to reject reject_types = [] if reject_artifacts: reject_types.append('Artefact') self.logger.debug("Configured to reject artifacts") if reject_arousals: reject_types.extend(['Arousal']) self.logger.debug("Configured to reject arousals") # Make sure method is a list if isinstance(method, str): method = [method] # Make sure chan is a list if isinstance(chan, str): chan = [chan] # Make sure stage is a list if isinstance(stage, str): stage = [stage] # Create json_dir if specified if json_dir: os.makedirs(json_dir, exist_ok=True) self.logger.info(f"Channel JSONs will be saved to: {json_dir}") # Verify required components if self.dataset is None: self.logger.error("Error: No dataset provided for slow wave detection") return [] if self.annotations is None and save_to_annotations: self.logger.warning("Warning: No annotations provided but annotation saving requested.") self.logger.warning("Slow waves will not be saved to annotations.") save_to_annotations = False # Convert method to string method_str = "_".join(method).replace('/', '_') if isinstance(method, list) else str(method).replace('/', '_') # Convert frequency to string freq_str = f"{frequency[0]}-{frequency[1]}Hz" self.logger.info(f"Starting slow wave detection with method={method_str}, frequency={freq_str}") self.logger.debug(f"Parameters: channels={chan}, reject_artifacts={reject_artifacts}, reject_arousals={reject_arousals}") # Log adaptive threshold parameters if applicable first_method = method[0] if isinstance(method, list) and len(method) > 0 else method if first_method == 'Ngo2015' and peak_thresh_sigma is not None and ptp_thresh_sigma is not None: self.logger.info(f"Using adaptive thresholds: peak_thresh_sigma={peak_thresh_sigma}, ptp_thresh_sigma={ptp_thresh_sigma}") # Create custom annotation file name if saving to annotations if save_to_annotations: # Convert channel list to string chan_str = "_".join(chan) if len(chan) <= 3 else f"{chan[0]}_plus_{len(chan)-1}_chans" # Create custom filename annotation_filename = f"slowwaves_{method_str}_{chan_str}_{freq_str}.xml" # Create full path if json_dir is specified if json_dir: annotation_file_path = os.path.join(json_dir, annotation_filename) else: # Use current directory annotation_file_path = annotation_filename # Create new annotation object if we're saving to a new file if self.annotations is not None: try: # Create a copy of the original annotations import shutil if hasattr(self.annotations, 'xml_file') and os.path.exists(self.annotations.xml_file): shutil.copy(self.annotations.xml_file, annotation_file_path) new_annotations = Annotations(annotation_file_path) try: sw_events = new_annotations.get_events('slow_wave') if sw_events: self.logger.info(f"Removing {len(sw_events)} existing slow wave events") new_annotations.remove_event_type('slow_wave') except Exception as e: self.logger.error(f"Note: No existing slow wave events to remove: {e}") else: # Create new annotations file from scratch with open(annotation_file_path, 'w') as f: f.write('<?xml version="1.0" ?>\n<annotations><dataset><filename>') if hasattr(self.dataset, 'filename'): f.write(self.dataset.filename) f.write('</filename></dataset><rater><name>Wonambi</name></rater></annotations>') new_annotations = Annotations(annotation_file_path) print(f"Will save slow waves to new annotation file: {annotation_file_path}") except Exception as e: self.logger.error(f"Error creating new annotation file: {e}") save_to_annotations = False new_annotations = None else: self.logger.warning("Warning: No annotations provided but annotation saving requested.") self.logger.error("Slow waves will not be saved to annotations.") save_to_annotations = False new_annotations = None # Store all detected slow waves all_slow_waves = [] for ch in chan: try: self.logger.info(f'Reading data for channel {ch}') # Fetch segments, filtering based on stage and artifacts segments = fetch(self.dataset, self.annotations, cat=cat, stage=stage, cycle=None, reject_epoch=True, reject_artf=reject_types) segments.read_data(ch, ref_chan, grp_name=grp_name) # Process each detection method channel_slow_waves = [] channel_json_slow_waves = [] ## Loop through methods for m, meth in enumerate(method): self.logger.info(f"Applying method: {meth}") for i, seg in enumerate(segments): self.logger.info(f'Detecting events, segment {i + 1} of {len(segments)}') # Create a copy of the segment for processing processed_seg = seg.copy() # Apply polarity adjustment if needed if polar == 'opposite': processed_seg['data'].data[0][0] = -processed_seg['data'].data[0][0] elif polar == 'normal': pass self.logger.debug(f'Applied polarity inversion to segment {i + 1}') if detrend: self.logger.debug(f'Applying detrend to segment {i + 1}') try: processed_seg['data'] = math(processed_seg['data'], operator='detrend', axis='time') except Exception as e: self.logger.error(f"Error detrending data: {e}") # Special handling for Ngo2015 with adaptive thresholds detection_kwargs = {} if meth == 'Ngo2015' and peak_thresh_sigma is not None and ptp_thresh_sigma is not None: # Store sigma thresholds as class variables that the detector will use detection_kwargs = { 'peak_thresh': peak_thresh_sigma, 'ptp_thresh': ptp_thresh_sigma } self.logger.debug(f"Using custom adaptive thresholds: {detection_kwargs}") # Define detection with parameters detection = DetectSlowWave( meth, frequency=frequency, # Use appropriate duration parameter based on method duration=trough_duration if meth in ['Massimini2004', 'AASM/Massimini2004'] else None, neg_peak_thresh=neg_peak_thresh, p2p_thresh=p2p_thresh, min_dur=min_dur if meth not in ['Massimini2004', 'AASM/Massimini2004'] else None, max_dur=max_dur if meth not in ['Massimini2004', 'AASM/Massimini2004'] else None, polar=polar, **detection_kwargs # Pass method-specific kwargs ) # Run detection slow_waves = detection(processed_seg['data']) if slow_waves and save_to_annotations and new_annotations is not None: slow_waves.to_annot(new_annotations, 'slow_wave') # Add to our results # Convert to dictionary format for consistency for sw in slow_waves: # Add UUID to each slow wave sw['uuid'] = str(uuid.uuid4()) # Add channel information sw['chan'] = ch channel_slow_waves.append(sw) # Add to JSON if json_dir: # Extract key properties in a serializable format sw_data = { 'uuid': sw['uuid'], 'chan': ch, 'start_time': float(sw.get('start', 0)), 'end_time': float(sw.get('end', 0)), 'trough_time': float(sw.get('trough_time', 0)), 'peak_time': float(sw.get('peak_time', 0)), 'duration': float(sw.get('dur', 0)), 'trough_val': float(sw.get('trough_val', 0)), 'peak_val': float(sw.get('peak_val', 0)), 'ptp': float(sw.get('ptp', 0)), 'method': meth } sw_data['stage'] = stage sw_data['freq_range'] = frequency channel_json_slow_waves.append(sw_data) all_slow_waves.extend(channel_slow_waves) self.logger.info(f"Found {len(channel_slow_waves)} slow waves in channel {ch}") stages_str = "".join(stage) if stage else "all" if json_dir : try: ch_json_file = os.path.join(json_dir, f"slowwaves_{method_str}_{freq_str}_{stages_str}_{ch}.json") # Create empty JSON if no waves found but flag is set if not channel_json_slow_waves and create_empty_json: self.logger.info(f"Creating empty JSON file for channel {ch} (no slow waves detected)") with open(ch_json_file, 'w') as f: json.dump([], f) elif channel_json_slow_waves: with open(ch_json_file, 'w') as f: json.dump(channel_json_slow_waves, f, indent=2) self.logger.info(f"Saved slow wave data for channel {ch} to {ch_json_file}") except Exception as e: self.logger.error(f"Error saving channel JSON: {e}") except Exception as e: self.logger.warning(f'WARNING: No slow waves in channel {ch}: {e}') # Create empty JSON file even in case of error if json_dir and create_empty_json: try: stages_str = "".join(stage) if stage else "all" ch_json_file = os.path.join(json_dir, f"slowwaves_{method_str}_{freq_str}_{stages_str}_{ch}.json") with open(ch_json_file, 'w') as f: json.dump([], f) self.logger.info(f"Created empty JSON file for channel {ch} after error") except Exception as json_e: self.logger.error(f"Error creating empty JSON for channel {ch}: {json_e}") # Save the new annotation file if needed if save_to_annotations and new_annotations is not None and all_slow_waves: try: new_annotations.save(annotation_file_path) self.logger.info(f"Saved {len(all_slow_waves)} slow waves to new annotation file: {annotation_file_path}") except Exception as e: self.logger.error(f"Error saving annotation file: {e}") # Return all detected slow waves self.logger.info(f"Total slow waves detected across all channels: {len(all_slow_waves)}") return all_slow_waves
[docs] def export_slow_wave_parameters_to_csv(self, json_input, csv_file, export_params='all', frequency=None, ref_chan=None, grp_name='eeg', n_fft_sec=4, file_pattern=None,skip_empty_files=True): """ Calculate slow wave parameters from JSON files and export to CSV. Parameters ---------- json_input : str or list Path to JSON file, directory of JSON files, or list of JSON files csv_file : str Path to output CSV file export_params : dict or str Parameters to export. If 'all', exports all available parameters frequency : tuple or None Frequency range for power calculations ref_chan : list or None Reference channel(s) for parameter calculation n_fft_sec : int FFT window size in seconds for spectral analysis file_pattern : str or None Pattern to filter JSON files if json_input is a directory """ from wonambi.trans.analyze import event_params, export_event_params import glob # Clean memory first self.clean_memory() self.logger.info("Calculating slow wave parameters for CSV export...") # Load slow waves from JSON file(s) json_files = [] if file_pattern: all_json_files = glob.glob(os.path.join(json_input, "*.json")) json_files = [f for f in all_json_files if f"{file_pattern}_" in os.path.basename(f) or f"{file_pattern}." in os.path.basename(f)] else: json_files = glob.glob(os.path.join(json_input, "*.json")) self.logger.info(f"Found {len(json_files)} JSON files matching pattern: {file_pattern}") # Load slow waves from JSON files all_slow_waves = [] empty_channels = [] for file in json_files: try: with open(file, 'r') as f: slow_waves = json.load(f) if isinstance(slow_waves, list): if len(slow_waves) > 0: all_slow_waves.extend(slow_waves) else: # Extract channel name from filename filename = os.path.basename(file) parts = filename.split('_') if len(parts) > 1: chan = parts[-1].replace('.json', '') empty_channels.append(chan) self.logger.info(f"File {file} contains an empty list (no slow waves)") else: self.logger.warning(f"Warning: Unexpected format in {file}") self.logger.info(f"Loaded {len(slow_waves) if isinstance(slow_waves, list) else 0} slow waves from {file}") except Exception as e: self.logger.error(f"Error loading {file}: {e}") if not all_slow_waves: self.logger.info("No slow waves found in the input files") # Create an empty CSV file with header to indicate processing was done if empty_channels and not skip_empty_files: try: with open(csv_file, 'w', newline='') as outfile: writer = csv.writer(outfile) writer.writerow(["No slow waves were detected in the following channels:"]) for chan in empty_channels: writer.writerow([chan]) self.logger.info(f"Created empty CSV file at {csv_file}") except Exception as e: self.logger.error(f"Error creating empty CSV: {e}") return None # Get frequency band from slow waves if not provided if frequency is None: try: if 'freq_range' in all_slow_waves[0]: freq_range = all_slow_waves[0]['freq_range'] if isinstance(freq_range, list) and len(freq_range) == 2: frequency = tuple(freq_range) elif isinstance(freq_range, str) and '-' in freq_range: freq_parts = freq_range.split('-') frequency = (float(freq_parts[0].replace('Hz', '').strip()), float(freq_parts[1].replace('Hz', '').strip())) self.logger.info(f"Using frequency range from JSON: {frequency}") except: frequency = (0.1, 4.0) # Default for slow waves self.logger.info(f"Using default frequency range: {frequency}") # Get sampling frequency from dataset try: s_freq = self.dataset.header['s_freq'] except: self.logger.error("Could not determine dataset sampling frequency") return None # Try to get recording start time recording_start_time = None try: if hasattr(self.dataset, 'header'): header = self.dataset.header if hasattr(header, 'start_time'): recording_start_time = header.start_time elif isinstance(header, dict) and 'start_time' in header: recording_start_time = header['start_time'] if recording_start_time: self.logger.info(f"Found recording start time: {recording_start_time}") else: self.logger.warning("Could not find recording start time in dataset header. Using relative time only.") except Exception as e: self.logger.error(f"Error getting recording start time: {e}") self.logger.warning("Using relative time only.") # Group slow waves by channel for more efficient processing waves_by_chan = {} for sw in all_slow_waves: chan = sw.get('chan') if chan not in waves_by_chan: waves_by_chan[chan] = [] waves_by_chan[chan].append(sw) self.logger.info(f"Grouped slow waves by {len(waves_by_chan)} channels") # Process each channel all_segments = [] # Load data for each channel and create segments for chan, waves in waves_by_chan.items(): self.logger.info(f"Processing {len(waves)} slow waves for channel {chan}") try: # Create time windows for slow waves wave_windows = [] for sw in waves: start_time = sw['start_time'] end_time = sw['end_time'] wave_windows.append((start_time, end_time)) # Create segments for i, (start_time, end_time) in enumerate(wave_windows): try: # Add buffer for FFT calculation buffer = 0.1 # 100ms buffer start_with_buffer = max(0, start_time - buffer) end_with_buffer = end_time + buffer # Read data data = self.dataset.read_data(chan=[chan], begtime=start_with_buffer, endtime=end_with_buffer) # Create segment seg = { 'data': data, 'name': 'slow_wave', 'start': start_time, 'end': end_time, 'n_stitch': 0, 'stage': waves[i].get('stage'), 'cycle': None, 'chan': chan, 'uuid': waves[i].get('uuid', str(i)) } all_segments.append(seg) except Exception as e: self.logger.error(f"Error creating segment for slow wave {start_time}-{end_time}: {e}") except Exception as e: self.logger.error(f"Error processing channel {chan}: {e}") if not all_segments: self.logger.error("No valid segments created for parameter calculation") return None self.logger.info(f"Created {len(all_segments)} segments for parameter calculation") # Calculate parameters n_fft = None if all_segments and n_fft_sec is not None: n_fft = int(n_fft_sec * s_freq) # Create temporary file temp_csv = csv_file + '.temp' try: # Calculate parameters self.logger.info(f"Calculating parameters with frequency band {frequency} and n_fft={n_fft}") params = event_params(all_segments, export_params, band=frequency, n_fft=n_fft) if not params: self.logger.info("No parameters calculated") return None # Export to temporary CSV self.logger.info("Exporting parameters to temporary file") export_event_params(temp_csv, params, count=None, density=None) # Store UUIDs uuid_dict = {} for i, segment in enumerate(all_segments): if 'uuid' in segment: uuid_dict[i] = segment['uuid'] # Process CSV self.logger.info("Processing CSV to remove summary rows and add HH:MM:SS format") with open(temp_csv, 'r', newline='') as infile, open(csv_file, 'w', newline='') as outfile: reader = csv.reader(infile) writer = csv.writer(outfile) # Read all rows all_rows = list(reader) # Find header row header_row_index = None start_time_index = None for i, row in enumerate(all_rows): if row and 'Start time' in row: header_row_index = i start_time_index = row.index('Start time') break if header_row_index is None or start_time_index is None: self.logger.error("Could not find 'Start time' column in CSV") with open(temp_csv, 'r') as src, open(csv_file, 'w') as dst: dst.write(src.read()) return params # Create filtered rows filtered_rows = [] # Add prefix rows for i in range(header_row_index): filtered_rows.append(all_rows[i]) # Add header row with additional columns header_row = all_rows[header_row_index].copy() header_row.insert(start_time_index + 1, 'Start time (HH:MM:SS)') if 'UUID' not in header_row: header_row.append('UUID') filtered_rows.append(header_row) # Add data rows for i in range(header_row_index + 5, len(all_rows)): row = all_rows[i] if not row: continue new_row = row.copy() # Add HH:MM:SS time format if len(row) > start_time_index: try: start_time_sec = float(row[start_time_index]) def sec_to_time(seconds): hours = int(seconds // 3600) minutes = int((seconds % 3600) // 60) sec = seconds % 60 return f"{hours:02d}:{minutes:02d}:{sec:06.3f}" # Calculate clock time if recording_start_time is not None: try: delta = datetime.timedelta(seconds=start_time_sec) event_time = recording_start_time + delta start_time_hms = event_time.strftime('%H:%M:%S.%f')[:-3] except: start_time_hms = sec_to_time(start_time_sec) else: start_time_hms = sec_to_time(start_time_sec) new_row.insert(start_time_index + 1, start_time_hms) except (ValueError, IndexError): new_row.insert(start_time_index + 1, '') else: new_row.insert(start_time_index + 1, '') # Add UUID segment_index = i - (header_row_index + 5) if segment_index in uuid_dict: new_row.append(uuid_dict[segment_index]) else: new_row.append('') filtered_rows.append(new_row) # Write filtered rows for row in filtered_rows: writer.writerow(row) # Remove temporary file try: os.remove(temp_csv) except: self.logger.info(f"Could not remove temporary file {temp_csv}") self.logger.info(f"Successfully exported to {csv_file} with HH:MM:SS time format") return params except Exception as e: self.logger.error(f"Error calculating parameters: {e}") import traceback traceback.print_exc() return None
[docs] def export_slow_wave_density_to_csv(self, json_input, csv_file, stage=None, file_pattern=None): """ Export slow wave statistics to CSV with both whole night and stage-specific densities. Parameters ---------- json_input : str or list Path to JSON file, directory of JSON files, or list of JSON files csv_file : str Path to output CSV file stage : str or list Sleep stage(s) to include file_pattern : str or None Pattern to filter JSON files """ import glob from collections import defaultdict # Load slow waves from JSON file(s) json_files = [] if file_pattern: all_json_files = glob.glob(os.path.join(json_input, "*.json")) json_files = [f for f in all_json_files if f"{file_pattern}_" in os.path.basename(f) or f"{file_pattern}." in os.path.basename(f)] else: json_files = glob.glob(os.path.join(json_input, "*.json")) self.logger.info(f"Found {len(json_files)} JSON files matching pattern: {file_pattern}") if not json_files: try: with open(csv_file, 'w', newline='') as outfile: writer = csv.writer(outfile) writer.writerow(["No JSON files found matching pattern:", file_pattern]) self.logger.info(f"Created empty CSV file at {csv_file}") except Exception as e: self.logger.error(f"Error creating empty CSV: {e}") return None # Prepare stages if stage is None: combined_stages = False stage_list = None elif isinstance(stage, list) and len(stage) > 1: combined_stages = True stage_list = stage combined_stage_name = "+".join(stage_list) self.logger.info(f"Calculating combined slow wave density for stages: {combined_stage_name}") elif isinstance(stage, list) and len(stage) == 1: combined_stages = False stage_list = [stage[0]] self.logger.info(f"Calculating slow wave density for stage: {stage_list[0]}") else: combined_stages = False stage_list = [stage] self.logger.info(f"Calculating slow wave density for stage: {stage}") # Load all slow waves all_slow_waves = [] for file in json_files: try: with open(file, 'r') as f: waves = json.load(f) all_slow_waves.extend(waves if isinstance(waves, list) else []) except Exception as e: self.logger.error(f"Error loading {file}: {e}") # Get stage durations epoch_duration_sec = 30 stage_counts = defaultdict(int) all_stages = self.annotations.get_stages() # Count epochs for s in all_stages: if s in ['Wake', 'NREM1', 'NREM2', 'NREM3', 'REM']: stage_counts[s] += 1 # Calculate durations stage_durations = {stg: count * epoch_duration_sec / 60 for stg, count in stage_counts.items()} total_duration_min = sum(stage_durations.values()) # Extract stages from slow waves if needed wave_stages = set() for sw in all_slow_waves: if not isinstance(sw, dict) or 'stage' not in sw: continue sw_stage = sw['stage'] if isinstance(sw_stage, list): for s in sw_stage: wave_stages.add(str(s)) else: wave_stages.add(str(sw_stage)) # Determine stages to process if stage is None: stages_to_process = sorted(wave_stages) combined_stages = False elif combined_stages: stages_to_process = [stage_list] else: stages_to_process = stage_list # Group slow waves by channel and stage waves_by_chan_stage = defaultdict(lambda: defaultdict(list)) waves_by_chan = defaultdict(list) for sw in all_slow_waves: if not isinstance(sw, dict): continue chan = sw.get('chan', sw.get('channel')) if not chan: continue waves_by_chan[chan].append(sw) if not combined_stages: if 'stage' in sw: sw_stages = sw['stage'] if isinstance(sw['stage'], list) else [sw['stage']] for sw_stage in sw_stages: sw_stage = str(sw_stage) waves_by_chan_stage[chan][sw_stage].append(sw) # Calculate statistics stage_channel_stats = defaultdict(dict) for chan in set(waves_by_chan.keys()): all_chan_waves = waves_by_chan[chan] for process_stage in stages_to_process: stage_waves = [] if combined_stages or (isinstance(process_stage, list) and len(process_stage) > 1): stages_to_include = process_stage if isinstance(process_stage, list) else stage_list stage_name_display = "+".join(stages_to_include) stages_set = set(str(s) for s in stages_to_include) stage_waves = [] seen_waves = set() for sw in all_chan_waves: if 'stage' not in sw: continue sw_stages = sw['stage'] if isinstance(sw['stage'], list) else [sw['stage']] sw_stages = set(str(s) for s in sw_stages) if sw_stages.intersection(stages_set) and id(sw) not in seen_waves: stage_waves.append(sw) seen_waves.add(id(sw)) stage_duration_min = sum(stage_durations.get(s, 0) for s in stages_to_include) else: s_str = str(process_stage) stage_waves = waves_by_chan_stage[chan].get(s_str, []) stage_name_display = process_stage stage_duration_min = stage_durations.get(s_str, 0) if len(stage_waves) == 0: continue # Calculate statistics stage_count = len(stage_waves) whole_night_count = len(all_chan_waves) stage_density = stage_count / stage_duration_min if stage_duration_min > 0 else 0 whole_night_density = whole_night_count / total_duration_min if total_duration_min > 0 else 0 # Calculate mean duration durations = [] for sw in stage_waves: if 'start_time' in sw and 'end_time' in sw: durations.append(sw['end_time'] - sw['start_time']) mean_duration = np.mean(durations) if durations else 0 # Store statistics key = tuple(process_stage) if isinstance(process_stage, list) else process_stage stage_channel_stats[key][chan] = { 'count': stage_count, 'stage_density': stage_density, 'whole_night_density': whole_night_density, 'mean_duration': mean_duration, 'stage_name_display': stage_name_display, 'stage_duration_min': stage_duration_min, } # Export to CSV with open(csv_file, 'w', newline='') as f: writer = csv.writer(f) # Add summary sections writer.writerow(['Whole Night Summary']) writer.writerow(['Total Recording Duration (min)', f'{total_duration_min:.2f}']) writer.writerow([]) writer.writerow(['Stage Duration Summary']) writer.writerow(['Stage', 'Duration (min)']) for stg in sorted(set(stage_durations.keys())): writer.writerow([stg, f"{stage_durations.get(stg, 0):.2f}"]) if combined_stages: combined_duration = sum(stage_durations.get(s, 0) for s in stage_list) writer.writerow([combined_stage_name, f"{combined_duration:.2f}"]) writer.writerow([]) # Process each stage for process_stage in stages_to_process: key = tuple(process_stage) if isinstance(process_stage, list) else process_stage if key not in stage_channel_stats: continue any_chan = next(iter(stage_channel_stats[key].keys())) stage_name_display = stage_channel_stats[key][any_chan]['stage_name_display'] writer.writerow([f"Sleep Stage: {stage_name_display}"]) writer.writerow([ 'Channel', 'Count', f'Density in {stage_name_display} (events/min)', 'Whole Night Density (events/min)', 'Mean Duration (s)' ]) for chan in sorted(stage_channel_stats[key].keys()): stats = stage_channel_stats[key][chan] writer.writerow([ chan, stats['count'], f"{stats['stage_density']:.4f}", f"{stats['whole_night_density']:.4f}", f"{stats['mean_duration']:.4f}" ]) writer.writerow([]) self.logger.info(f"Exported slow wave statistics to {csv_file}") return dict(stage_channel_stats)
[docs] def save_detection_summary(self, output_dir, method, parameters, results_summary): """ Save a comprehensive summary of detection parameters and results. Parameters ---------- output_dir : str Directory to save the summary method : str Detection method used parameters : dict All parameters used for detection results_summary : dict Summary of detection results """ try: import datetime timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') summary_file = os.path.join(output_dir, f"detection_summary_{method}_{timestamp}.json") summary_data = { 'detection_method': method, 'parameters': parameters, 'results': results_summary, 'timestamp': datetime.datetime.now().isoformat(), 'software_version': 'TurtleWave hdEEG GUI' } with open(summary_file, 'w') as f: json.dump(summary_data, f, indent=2) self.logger.info(f"Saved detection summary to: {summary_file}") return summary_file except Exception as e: self.logger.error(f"Error saving detection summary: {e}") return None
############# SQLite Database Initialization and Import Functions #############
[docs] def initialize_sqlite_database(self, db_path='neural_events.db'): """ Create SQLite database optimized for storing calculated event parameters from event_params() function. Parameters ---------- db_path : str Path to SQLite database file Returns ------- str Path to created database """ import sqlite3 import os # If db_path is a directory, append the default filename if os.path.isdir(db_path): db_path = os.path.join(db_path, 'neural_events.db') self.logger.info(f"Database path was a directory, using: {db_path}") # Create directory for database if it doesn't exist db_dir = os.path.dirname(db_path) if db_dir and not os.path.exists(db_dir): os.makedirs(db_dir, exist_ok=True) self.logger.info(f"Created directory for database: {db_dir}") # Check if database exists db_exists = os.path.exists(db_path) # Define the database initialization operation def init_db(conn): cursor = conn.cursor() # Main events table with common fields across all event types conn.execute(''' CREATE TABLE IF NOT EXISTS events ( uuid TEXT PRIMARY KEY, event_type TEXT, -- 'spindle', 'slow_wave', 'ripple', etc. channel TEXT, -- Basic temporal properties start_time REAL, end_time REAL, duration REAL, start_time_hms TEXT, -- formatted time (HH:MM:SS) stage TEXT, cycle TEXT, -- sleep cycle method TEXT, -- Frequency band information freq_band TEXT, -- Full text representation (e.g. "0.5-3Hz") freq_lower REAL, -- Lower bound of frequency band (e.g. 0.5) freq_upper REAL, -- Upper bound of frequency band (e.g. 3.0) -- Amplitude metrics min_amp REAL, -- minimum amplitude max_amp REAL, -- maximum amplitude peak2peak_amp REAL, -- peak-to-peak amplitude -- Processing metadata processing_timestamp TEXT, n_fft_sec INTEGER, CONSTRAINT event_chan_time UNIQUE (event_type, channel, start_time, method, freq_lower, freq_upper, stage) )''') # Create tracking table for batch processing conn.execute(''' CREATE TABLE IF NOT EXISTS processing_status ( channel TEXT, event_type TEXT, json_file TEXT, processed BOOLEAN DEFAULT 0, attempts INTEGER DEFAULT 0, last_attempt_time TEXT, success BOOLEAN DEFAULT 0, error_message TEXT, PRIMARY KEY (channel, event_type) )''') # Create indexes for efficient querying conn.execute('CREATE INDEX IF NOT EXISTS idx_event_type ON events(event_type)') conn.execute('CREATE INDEX IF NOT EXISTS idx_channel ON events(channel)') conn.execute('CREATE INDEX IF NOT EXISTS idx_timerange ON events(start_time, end_time)') conn.execute('CREATE INDEX IF NOT EXISTS idx_stage ON events(stage)') conn.commit() # If database didn't exist, log creation if not db_exists: self.logger.info(f"Created new database at: {db_path}") return db_path # Use the safe database operation return self._safe_database_operation(db_path, init_db)
def _safe_database_operation(self, db_path, operation_func): """Safely perform a database operation with proper connection handling""" import sqlite3 conn = None try: conn = sqlite3.connect(db_path) result = operation_func(conn) return result except Exception as e: self.logger.error(f"Database error: {e}") raise finally: if conn: conn.close()
[docs] def import_parameters_csv_to_database(self, csv_file, db_path, append=True): """ Import event parameters from an existing CSV file into SQLite database. Supports multiple event types and incremental updates. Parameters ---------- csv_file : str Path to existing parameters CSV file db_path : str Path to SQLite database append : bool If True, adds to existing database without replacing existing entries If False, replaces any existing entries with the same UUID Returns ------- dict Summary of the operation with counts of added, updated, and skipped rows """ import sqlite3 import pandas as pd import os import glob self.clean_memory() # Initialize database if needed if not os.path.exists(db_path): self.initialize_sqlite_database(db_path) # Check if the file exists if not os.path.exists(csv_file): self.logger.error(f"CSV file not found: {csv_file}") return {"error": "CSV file not found", "added": 0, "updated": 0, "skipped": 0} # Track statistics stats = { "added": 0, "updated": 0, "skipped": 0 } # Read the CSV file self.logger.info(f"Reading parameters from CSV: {csv_file}") try: # First determine how many rows to skip (header plus statistics) with open(csv_file, 'r') as f: lines = f.readlines() # Find the header row (contains 'Start time') header_row = None for i, line in enumerate(lines): if 'Start time' in line: header_row = i break if header_row is None: self.logger.error("Could not find header row in CSV") return {"error": "Could not find header row", "added": 0, "updated": 0, "skipped": 0} # Check if there are statistic rows after the header has_stat_rows = False if header_row + 1 < len(lines): next_line = lines[header_row + 1] # Check if the next line starts with "Mean" or contains statistical summaries if next_line.strip().startswith('Mean') or 'Mean' in next_line: has_stat_rows = True # Skip header row and 4 statistic rows skiprows = header_row + 4 if has_stat_rows else header_row # Read the CSV, skipping header and statistics df = pd.read_csv(csv_file, skiprows=skiprows) if df.empty: self.logger.warning("CSV file contains no data rows") return {"error": "Empty CSV file", "added": 0, "updated": 0, "skipped": 0} self.logger.info(f"Read {len(df)} parameter rows from CSV") # Define database operation function def process_csv_data(conn): cursor = conn.cursor() # Determine event type from CSV filename or content event_type = "slow_wave" # Default filename = os.path.basename(csv_file).lower() if 'slow_wave' in filename or 'slowwave' in filename or 'sw' in filename: event_type = "slow_wave" elif 'spindle' in filename: event_type = "spindle" # Override event_type if 'Event type' column exists in CSV if 'Event type' in df.columns: # Use the first non-null value in the Event type column event_types = df['Event type'].dropna() if len(event_types) > 0: event_type = event_types.iloc[0] self.logger.info(f"Importing parameters for event type: {event_type}") # Map column names from CSV to database columns column_mapping = { 'Start time': 'start_time', 'Start time (HH:MM:SS)': 'start_time_hms', 'End time': 'end_time', 'Stage': 'stage', 'Cycle': 'cycle', 'Event type': 'event_type', 'Channel': 'channel', 'Duration (s)': 'duration', 'Min. amplitude (uV)':'min_amp', 'Max. amplitude (uV)': 'max_amp', 'Peak-to-peak amplitude (uV)': 'peak2peak_amp', #'RMS (uV)': 'rms', #'Power (uV^2)': 'power', #'Peak power frequency (Hz)': 'peak_power_freq', #'Energy (uV^2s)': 'energy', #'Peak energy frequency': 'peak_energy_freq', 'UUID': 'uuid' } # Create a list of columns that exist in the dataframe existing_columns = [] db_columns = [] for csv_col, db_col in column_mapping.items(): if csv_col in df.columns: existing_columns.append(csv_col) db_columns.append(db_col) # Add processing timestamp import datetime now = datetime.datetime.now().isoformat() df['processing_timestamp'] = now existing_columns.append('processing_timestamp') db_columns.append('processing_timestamp') # Extract frequency band from filename if possible filename = os.path.basename(csv_file) freq_band = "unknown" # Try to extract frequency from filename (e.g., sw_parameters_Staresina2015_0.3-2.0Hz_NREM2NREM3.csv) if "_" in filename and "Hz" in filename: parts = filename.split('_') for part in parts: if "Hz" in part: freq_band = part try: # Handle formats like "9-12Hz" or "9.0-12.0Hz" freq_parts = freq_band.replace("Hz", "").split("-") if len(freq_parts) == 2: freq_lower = float(freq_parts[0]) freq_upper = float(freq_parts[1]) except ValueError: self.logger.warning(f"Could not parse frequency bounds from {freq_band}") break df['freq_band'] = freq_band df['freq_lower'] = freq_lower df['freq_upper'] = freq_upper existing_columns.append('freq_band') existing_columns.append('freq_lower') existing_columns.append('freq_upper') db_columns.append('freq_band') db_columns.append('freq_lower') db_columns.append('freq_upper') # Extract method from filename if possible method = "unknown" if "_" in filename: parts = filename.split('_') if len(parts) > 2: # Typically the format is sw_parameters_METHOD_freq_stages.csv method = parts[2] df['method'] = method existing_columns.append('method') db_columns.append('method') # Set event_type from our detection df['event_type'] = event_type if 'event_type' not in db_columns: existing_columns.append('event_type') db_columns.append('event_type') # Check for UUID column, which is essential for avoiding duplicates uuid_col = 'UUID' if 'UUID' in df.columns else 'uuid' if 'uuid' in df.columns else None # If no UUID column, create one if uuid_col is None: self.logger.warning("No UUID column found, creating UUIDs based on channel and time") import uuid df['uuid'] = [ str(uuid.uuid4()) for _ in range(len(df)) ] uuid_col = 'uuid' existing_columns.append('uuid') db_columns.append('uuid') # Check if the required columns for uniqueness constraint exist if 'Channel' not in df.columns or 'Start time' not in df.columns: self.logger.warning("Missing required columns for uniqueness check") # Pre-check existing events by unique constraint (event_type, channel, start_time, method) # rather than just UUID to avoid constraint violations existing_events = set() if append and 'Channel' in df.columns and 'Start time' in df.columns: # Get all unique combinations of event_type, channel, start_time channels = df['Channel'].astype(str).tolist() start_times = df['Start time'].astype(float).tolist() # Build a query to get existing events matching these combinations query_parts = [] query_params = [] for i in range(len(channels)): freq_lower = df['freq_lower'].iloc[i] if 'freq_lower' in df.columns else None freq_upper = df['freq_upper'].iloc[i] if 'freq_upper' in df.columns else None stage = df['Stage'].iloc[i] if 'Stage' in df.columns else None query_parts.append("(event_type = ? AND channel = ? AND start_time = ? AND method = ? AND freq_lower = ? AND freq_upper = ? AND stage = ?)") query_params.extend([event_type, channels[i], start_times[i],method, freq_lower, freq_upper, stage]) if query_parts: query = f"SELECT event_type, channel, start_time, method, freq_lower, freq_upper, stage FROM events WHERE {' OR '.join(query_parts)}" cursor.execute(query, query_params) for row in cursor.fetchall(): # Create a tuple of (event_type, channel, start_time. method) to check against existing_events.add((row[0], row[1], row[2], row[3], row[4], row[5], row[6])) self.logger.info(f"Found {len(existing_events)} existing entries matching event type, channel, and start time") # Mark rows that exist in the database based on the uniqueness constraint df['exists_in_db'] = df.apply( lambda row: ( event_type, str(row.get('Channel', '')), float(row.get('Start time', 0)), method, row.get('freq_lower', None), row.get('freq_upper', None), str(row.get('Stage','')) ) in existing_events, axis=1 ) # # If appending, we need to check which rows already exist in the database # if append and uuid_col: # # Get all UUIDs from the dataframe # all_uuids = df[uuid_col].astype(str).tolist() # # Check which UUIDs already exist in the database # placeholders = ','.join(['?' for _ in all_uuids]) # cursor.execute(f"SELECT uuid FROM events WHERE uuid IN ({placeholders})", all_uuids) # existing_uuids = {row[0] for row in cursor.fetchall()} # self.logger.info(f"Found {len(existing_uuids)} existing entries in database") # # Mark rows that already exist in the database # df['exists_in_db'] = df[uuid_col].apply(lambda x: str(x) in existing_uuids) # else: # # If not appending, mark all rows as not existing # df['exists_in_db'] = False # Process each row based on whether it exists and append mode for _, row in df.iterrows(): if isinstance(row['Stage'], list): row['Stage'] = '+'.join(row['Stage']) elif isinstance(row['Stage'], str) and '[' in row['Stage']: # Sometimes stage might be a string representation of a list like "['NREM2', 'NREM3']" # Try to convert it to a proper list then join try: import ast stage_list = ast.literal_eval(row['Stage']) if isinstance(stage_list, list): row['Stage'] = ''.join(stage_list) except: # If conversion fails, keep as is pass # Skip existing rows when in append mode if append and row['exists_in_db']: stats["skipped"] += 1 continue values = [row[col] if col in row else None for col in existing_columns] # Handle NaN values for i, val in enumerate(values): # Check if value is NaN (using pandas or numpy's isnan) if pd.isna(val) or (hasattr(val, 'isnan') and val.isnan()): values[i] = None # Convert NaN to None (which becomes NULL in SQLite) try: if append and row['exists_in_db']: # Skip existing rows when in append mode stats["skipped"] += 1 continue if not append and row['exists_in_db']: # Update existing row when not in append mode update_columns = [col for col in db_columns if col != 'uuid'] update_values = [val for i, val in enumerate(values) if db_columns[i] != 'uuid'] # Update based on the unique constraint, not just UUID cursor.execute(f""" UPDATE events SET {', '.join([f'{col} = ?' for col in update_columns])} WHERE event_type = ? AND channel = ? AND start_time = ? AND method = ? AND freq_lower = ? AND freq_upper = ? AND stage = ? """, update_values + [ event_type, row.get('Channel', ''), row.get('Start time', 0), method, row.get('freq_lower', None), row.get('freq_upper', None), str(row.get('Stage', '')) ]) stats["updated"] += 1 else: # Insert new row - use REPLACE to handle any constraint violations cursor.execute(f""" INSERT OR REPLACE INTO events ({', '.join(db_columns)}) VALUES ({', '.join(['?' for _ in db_columns])}) """, values) stats["added"] += 1 except Exception as e: self.logger.error(f"Error processing row: {e}") stats["skipped"] += 1 conn.commit() self.logger.info(f"Database updated: {stats['added']} added, {stats['updated']} updated, {stats['skipped']} skipped") # Update processing status #cursor.execute("PRAGMA table_info(processing_status)") #columns = cursor.fetchall() #print("Columns in processing_status table:", columns) # Update processing status with handling for both channels with events and empty channels if 'Channel' in df.columns: processed_channels = set(df['Channel'].unique()) # Add channels that have events in the CSV for channel in processed_channels: cursor.execute(''' INSERT OR REPLACE INTO processing_status (channel, event_type, processed, success, attempts, last_attempt_time) VALUES (?, ?, 1, 1, 1, datetime('now')) ''', (channel,event_type)) # Try to identify empty channels from JSON filenames # Note: This assumes the CSV file name contains information to identify related JSON files csv_basename = os.path.basename(csv_file) parts = csv_basename.split('_') if len(parts) >= 3: # For CSVs like: spindle_parameters_Ferrarelli2007_9-12Hz_NREM2NREM3.csv # Matching JSONs like: spindles_Ferrarelli2007_9-12Hz_NREM2NREM3_E101.json # Extract the method and frequency-stage parts method = parts[2] # Ferrarelli2007 freq_stage = parts[3:] # ['9-12Hz', 'NREM2NREM3'] freq_stage_str = '_'.join(freq_stage).replace('.csv', '') # Construct pattern to find related JSON files json_pattern = f"{event_type}s_{method}_{freq_stage_str}_*" # Find JSON files matching the pattern json_dir = os.path.dirname(csv_file) all_json_files = glob.glob(os.path.join(json_dir, f"{json_pattern}.json")) self.logger.info(f"Looking for JSON files matching pattern: {json_pattern}.json") self.logger.info(f"Found {len(all_json_files)} matching JSON files") # Extract channel names from JSON files empty_channels = set() for file in all_json_files: try: # Extract channel name from filename # Assuming format like "spindles_method_freq_stage_CHANNELNAME.json" channel_name = os.path.basename(file).split('_')[-1].replace('.json', '') # Skip if channel already in processed_channels if channel_name in processed_channels: continue # Read JSON file to check if it's empty with open(file, 'r') as f: content = json.load(f) # If JSON file contains an empty array, add to empty_channels if isinstance(content, list) and len(content) == 0: empty_channels.add(channel_name) self.logger.info(f"Found empty JSON file for channel: {channel_name}") except Exception as e: self.logger.warning(f"Error checking JSON file {file}: {e}") # Add empty channels to processing_status for channel in empty_channels: cursor.execute(''' INSERT OR REPLACE INTO processing_status (channel, event_type, processed, success, attempts, last_attempt_time, error_message) VALUES (?, ?, 1, 1, 1, datetime('now'), 'No events detected') ''', (channel,event_type)) if empty_channels: self.logger.info(f"Recorded {len(empty_channels)} channels with no events: {', '.join(empty_channels)}") # Add empty channels count to stats stats["empty_channels"] = len(empty_channels) conn.commit() # Get total count cursor.execute("SELECT COUNT(*) FROM events") total_count = cursor.fetchone()[0] self.logger.info(f"Total parameters in database: {total_count}") conn.close() return stats # Use the safe database operation return self._safe_database_operation(db_path, process_csv_data) except Exception as e: self.logger.error(f"Error processing CSV: {e}") import traceback traceback.print_exc() return {"error": str(e), "added": 0, "updated": 0, "skipped": 0}