Source code for turtlewave_hdEEG.pacprocessor

##pacprocessor.py


"""
pac_processor.py
A class for phase-amplitude coupling (PAC) analysis for high-density EEG data.
Based on the OCTOPUS method from the seapipe package.
"""

import os
import sys
import numpy as np
import time
import json
import csv
import logging
from wonambi.dataset import Dataset
from wonambi.attr import Annotations
from wonambi.trans import fetch
from copy import deepcopy
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
import pandas as pd
from datetime import datetime


[docs] class ParalPAC: """ A class for parallel detection and analysis of phase-amplitude coupling (PAC) across multiple channels of high-density EEG data. """ def __init__(self, dataset, annotations=None, rootpath=None, log_level=logging.INFO, log_file=None): """ Initialize the ParalPAC object. Parameters ---------- dataset : Dataset Dataset object containing EEG data annotations : Annotations Annotations object for storing and retrieving events rootpath : str Root path for input/output operations log_level : int Logging level (e.g., logging.DEBUG, logging.INFO) log_file : str or None Path to log file. If None, logs to console only. """ self.dataset = dataset self.annotations = annotations self.rootpath = rootpath if rootpath else os.path.dirname(os.path.dirname(dataset.filename)) # Setup logging self.logger = self._setup_logger(log_level, log_file) # Initialize the tracking dictionary self.tracking = {'event_pac': {}} def _setup_logger(self, log_level, log_file=None): """Set up a logger for the PAC processor.""" # Create a logger logger = logging.getLogger('turtlewave_hdEEG.pacprocessor') logger.setLevel(log_level) # Create formatter formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') # Create console handler console_handler = logging.StreamHandler() console_handler.setFormatter(formatter) logger.addHandler(console_handler) # Create file handler if log_file specified if log_file: file_handler = logging.FileHandler(log_file) file_handler.setFormatter(formatter) logger.addHandler(file_handler) return logger
[docs] def pac_method(self, method, surrogate, correction, list_methods=False): """ Format the method and corrections to be applied through Tensorpac. Adapted from OCTOPUS module. Parameters ---------- method : int PAC method number surrogate : int Surrogate method number correction : int Correction method number list_methods : bool If True, return a list of method descriptions Returns ------- tuple or list Either a tuple of (method, surrogate, correction) or a list of descriptions """ # Calculate Coupling Strength (idpac) methods = {1: 'Mean Vector Length (MVL) [Canolty et al. 2006 (Science)]', 2: 'Modulation Index (MI) [Tort 2010 (J Neurophys.)]', 3: 'Heights Ratio (HR) [Lakatos 2005 (J Neurophys.)]', 4: 'ndPAC [Ozkurt 2012 (IEEE)]', 5: 'Phase-Locking Value (PLV) [Penny 2008 (J. Neuro. Meth.), Lachaux 1999 (HBM)]', 6: 'Gaussian Copula PAC (GCPAC) `Ince 2017 (HBM)`'} surrogates = {0: 'No surrogates', 1: 'Swap phase / amplitude across trials [Tort 2010 (J Neurophys.)]', 2: 'Swap amplitude time blocks [Bahramisharif 2013 (J. Neurosci.) ]', 3: 'Time lag [Canolty et al. 2006 (Science)]'} corrections = {0: 'No normalization', 1: 'Subtract the mean of surrogates', 2: 'Divide by the mean of surrogates', 3: 'Subtract then divide by the mean of surrogates', 4: 'Z-score'} if list_methods: return [methods, surrogates, corrections] else: return (method, surrogate, correction)
[docs] def analyze_pac(self, chan=None, ref_chan=None, grp_name='eeg', stage=None, rater=None, reject_artf=['Artefact', 'Arousal'], cycle_idx=None, cat=(1,1,1,0), nbins=18, phase_freq=(0.5, 1.25), amp_freq=(11, 16), idpac=(2, 3, 4), min_dur=1, adap_bands_phase='Fixed', adap_bands_amplitude='Fixed', filter_opts=None, event_opts=None, invert=False, use_detected_events=True, event_type='slow_wave', pair_with_spindles=False, time_window=0.5, db_path=None, out_dir=None, progress=False): """ Analyze phase-amplitude coupling (PAC) in the dataset. Parameters ---------- chan : list or str Channels to analyze ref_chan : list or str Reference channel(s) for re-referencing grp_name : str Group name for channel selection stage : list or str Sleep stage(s) to analyze rater : str Rater name for annotations reject_artf : list Event types to reject cycle_idx : list or None Sleep cycle indices to include cat : tuple Category specification for data selection nbins : int Number of phase bins phase_freq : tuple Frequency range for phase signal amp_freq : tuple Frequency range for amplitude signal idpac : tuple PAC method settings (method, surrogate, correction) min_dur : float Minimum event duration in seconds adap_bands_phase : str Type of frequency band adaptation for phase adap_bands_amplitude : str Type of frequency band adaptation for amplitude filter_opts : dict Signal filtering options event_opts : dict Event processing options invert : bool Whether to invert signal polarity use_detected_events : bool Whether to use detected events for PAC analysis event_type : str Type of events to use ('slow_wave' or 'spindle') pair_with_spindles : bool If True and event_type is 'slow_wave', will pair slow waves with spindles time_window : float Time window (in seconds) to search for spindles around slow waves db_path : str Path to the SQLite database containing events out_dir : str Output directory for results progress : bool Whether to show progress bar Returns ------- dict Dictionary containing PAC results """ from tensorpac import Pac import sys import sqlite3 # Set up logger logger = self.logger # Get method descriptions pac_list = self.pac_method(0, 0, 0, list_methods=True) methods = pac_list[0] surrogates = pac_list[1] corrections = pac_list[2] # Set up tracking tracking = self.tracking flag = 0 # Set up default filter options if not provided # https://etiennecmb.github.io/tensorpac/generated/tensorpac.Pac.html?highlight=cycle#tensorpac.Pac.cycle if filter_opts is None: filter_opts = { 'notch': True, 'notch_freq': 50, 'notch_harmonics': True, 'bandpass': True, 'highpass': 0.1, 'lowpass': 45, 'laplacian': False, 'dcomplex': 'hilbert', 'filtcycle': [3, 6], 'width': 7 } # Set up default event options if not provided if event_opts is None: event_opts = { 'buffer': 1.0 # Buffer in seconds } logger.info("") logger.info(""" ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ___ ___ _____ ___________ _ _____ / _ \/ __/_ _/__ / __/ _ | |/|/ / _ \\ / // / _/ / / _/ /_\ \/ __ | / ___/ /____/___/ /_/ /___/___/_/ |_/_/|_/_/ Phase-Amplitude Coupling Analysis ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ """) logger.info(f"Method: {methods[idpac[0]]}") logger.info(f"Surrogate: {surrogates[idpac[1]]}") logger.info(f"Correction: {corrections[idpac[2]]}") # Log filtering options logger.info(f"Using {adap_bands_phase} bands for phase frequency") logger.info(f"Using {adap_bands_amplitude} bands for amplitude frequency") if filter_opts['notch']: logger.info(f"Applying notch filtering: {filter_opts['notch_freq']} Hz") if filter_opts['notch_harmonics']: logger.info("Applying notch harmonics filtering") if filter_opts['bandpass']: logger.info(f"Applying bandpass filtering: {filter_opts['highpass']} - {filter_opts['lowpass']} Hz") if filter_opts['laplacian']: logger.info("Applying Laplacian filtering") # 1. Check directories if out_dir: base_out_dir = out_dir else: base_out_dir = os.path.join(self.rootpath, "wonambi", "pac_results") os.makedirs(base_out_dir, exist_ok=True) logger.info(f"Using base output directory: {base_out_dir}") # 2. Process channel input if isinstance(chan, str): chan = [chan] # 3. Process stage input if isinstance(stage, str): stage = [stage] # 4. Determine database path if db_path is None: db_path = os.path.join(self.rootpath, "wonambi", "neural_events.db") logger.info(f"Using default database path: {db_path}") if not os.path.exists(db_path): logger.error(f"Database file not found: {db_path}") return None # 5. Begin channel processing for c, ch in enumerate(chan): chan_results = {} logger.info(f"Processing channel: {ch}") # Prepare output filename if adap_bands_phase == 'Fixed': phadap = '-fixed' else: phadap = '-adap' if adap_bands_amplitude == 'Fixed': ampadap = '-fixed' else: ampadap = '-adap' phaname1 = round(phase_freq[0], 2) phaname2 = round(phase_freq[1], 2) ampname1 = round(amp_freq[0], 2) ampname2 = round(amp_freq[1], 2) freqs = f'pha-{phaname1}-{phaname2}Hz{phadap}_amp-{ampname1}-{ampname2}Hz{ampadap}' # Extract method information before creating output directories sw_method = event_opts.get('sw_method', 'unknown') if event_opts else 'unknown' spindle_method = event_opts.get('spindle_method', 'unknown') if event_opts else 'unknown' # Create a method-specific output directory stage_str = ''.join(stage) if isinstance(stage, list) else str(stage) # Use consistent directory structure for all output files if pair_with_spindles and event_type == 'slow_wave': # For slow wave-spindle pairing, include both methods method_dir = f"{sw_method}_paired_{spindle_method}" else: # For single event type analysis method_dir = sw_method if event_type == 'slow_wave' else spindle_method # Create the full output directory path method_out_dir = os.path.join(base_out_dir, method_dir, stage_str) os.makedirs(method_out_dir, exist_ok=True) logger.info(f"Using method-specific output directory: {method_out_dir}") # Create output filenames using the method-specific directory if pair_with_spindles and event_type == 'slow_wave': outputfile = f'{method_out_dir}/{ch}_slowwave_spindle_coupling_{freqs}_pac_parameters.csv' else: outputfile = f'{method_out_dir}/{ch}_{event_type}_{freqs}_pac_parameters.csv' # 6. Fetch data segments try: logger.info(f"Fetching data segments for {ch}") if use_detected_events: # Get events from SQLite database logger.info(f"Using detected {event_type} events from database") # Connect to database try: conn = sqlite3.connect(db_path) cursor = conn.cursor() # Construct SQL query based on parameters if event_type == 'slow_wave': # Get slow waves from the database query = """ SELECT uuid, channel, start_time, end_time, duration, stage, method, freq_lower, freq_upper FROM events WHERE event_type = 'slow_wave' AND channel = ? """ # Initialize params list params = [ch] # HERE IS MODIFY: Initialize params list with channel # Add method filter if specified if 'sw_method' in event_opts and event_opts['sw_method']: query += " AND method = ?" params.append(event_opts['sw_method']) # Add frequency range filter if specified if 'sw_freq_range' in event_opts and event_opts['sw_freq_range'] and len(event_opts['sw_freq_range']) == 2: query += " AND freq_lower >= ? AND freq_upper <= ?" params.extend(event_opts['sw_freq_range']) # Add stage filter if specified if stage and len(stage) > 0: placeholders = ', '.join(['?' for _ in stage]) query += f" AND stage IN ({placeholders})" params.extend(stage) #params = [ch] + stage #else: # params = [ch] # Execute query cursor.execute(query, params) slow_wave_events = cursor.fetchall() logger.info(f"Found {len(slow_wave_events)} slow wave events for channel {ch}") if pair_with_spindles: logger.info("Looking for slow wave-spindle pairs") # Initialize list for paired events paired_events = [] # For each slow wave, find spindles that occur within the time window for sw in slow_wave_events: sw_uuid, sw_chan, sw_start, sw_end, sw_dur, sw_stage, sw_method, sw_freq_lower, sw_freq_upper = sw # Define search window around the slow wave search_start = sw_start - time_window search_end = sw_end + time_window # Find spindles within this window spindle_query = """ SELECT uuid, channel, start_time, end_time, duration, stage, method, freq_lower, freq_upper FROM events WHERE event_type = 'spindle' AND channel = ? AND ((start_time >= ? AND start_time <= ?) OR (end_time >= ? AND end_time <= ?) OR (start_time <= ? AND end_time >= ?)) """ # Initialize spindle_params list with search parameters spindle_params = [ch, search_start, search_end, search_start, search_end, search_start, search_end] # Add method filter if specified if 'spindle_method' in event_opts and event_opts['spindle_method']: spindle_query += " AND method = ?" spindle_params.append(event_opts['spindle_method']) # Add frequency range filter if specified if 'spindle_freq_range' in event_opts and event_opts['spindle_freq_range'] and len(event_opts['spindle_freq_range']) == 2: spindle_query += " AND freq_lower >= ? AND freq_upper <= ?" spindle_params.extend(event_opts['spindle_freq_range']) cursor.execute(spindle_query, spindle_params) related_spindles = cursor.fetchall() if related_spindles: for sp in related_spindles: sp_uuid, sp_chan, sp_start, sp_end, sp_dur, sp_stage, sp_method, sp_freq_lower, sp_freq_upper = sp # Create a pair record paired_events.append({ 'sw_uuid': sw_uuid, 'sp_uuid': sp_uuid, 'channel': ch, 'sw_start': sw_start, 'sw_end': sw_end, 'sp_start': sp_start, 'sp_end': sp_end, 'stage': sw_stage, 'sw_method': sw_method, 'sp_method': sp_method }) logger.info(f"Found {len(paired_events)} slow wave-spindle pairs for channel {ch}") if len(paired_events) == 0: logger.warning(f"No slow wave-spindle pairs found for channel {ch}") continue # Create segments from paired events segments = [] for pair in paired_events: try: # Define analysis window that encompasses both events start_time = min(pair['sw_start'], pair['sp_start']) end_time = max(pair['sw_end'], pair['sp_end']) # Add buffer buffer = event_opts['buffer'] start_with_buffer = max(0, start_time - buffer) end_with_buffer = end_time + buffer # Read data data = self.dataset.read_data(chan=[ch], begtime=start_with_buffer, endtime=end_with_buffer) # Create segment seg = { 'data': data, 'name': 'sw_spindle_pair', 'start': start_time, 'end': end_time, 'n_stitch': 0, 'stage': pair['stage'], 'cycle': None, 'chan': ch, 'sw_uuid': pair['sw_uuid'], 'sp_uuid': pair['sp_uuid'] } segments.append(seg) except Exception as e: logger.error(f"Error creating segment for paired events: {e}") else: # Use slow waves directly segments = [] for sw in slow_wave_events: sw_uuid, sw_chan, sw_start, sw_end, sw_dur, sw_stage, sw_method, sw_freq_lower, sw_freq_upper = sw try: # Add buffer buffer = event_opts['buffer'] start_with_buffer = max(0, sw_start - buffer) end_with_buffer = sw_end + buffer # Read data data = self.dataset.read_data(chan=[ch], begtime=start_with_buffer, endtime=end_with_buffer) # Create segment seg = { 'data': data, 'name': 'slow_wave', 'start': sw_start, 'end': sw_end, 'n_stitch': 0, 'stage': sw_stage, 'cycle': None, 'chan': ch, 'uuid': sw_uuid } segments.append(seg) except Exception as e: logger.error(f"Error creating segment for slow wave {sw_uuid}: {e}") elif event_type == 'spindle': # Get spindles from the database query = """ SELECT uuid, channel, start_time, end_time, duration, stage, method, freq_lower, freq_upper FROM events WHERE event_type = 'spindle' AND channel = ? """ # Initialize params list params = [ch] # Initialize params list with channel # Add method filter if specified if 'spindle_method' in event_opts and event_opts['spindle_method']: query += " AND method = ?" params.append(event_opts['spindle_method']) # Add frequency range filter if specified if 'spindle_freq_range' in event_opts and event_opts['spindle_freq_range'] and len(event_opts['spindle_freq_range']) == 2: query += " AND freq_lower >= ? AND freq_upper <= ?" params.extend(event_opts['spindle_freq_range']) # Add stage filter if specified if stage and len(stage) > 0: placeholders = ', '.join(['?' for _ in stage]) query += f" AND stage IN ({placeholders})" params.extend(stage) # Execute query cursor.execute(query, params) spindle_events = cursor.fetchall() logger.info(f"Found {len(spindle_events)} spindle events for channel {ch}") if len(spindle_events) == 0: logger.warning(f"No spindle events found for channel {ch}") continue # Create segments from spindles segments = [] for sp in spindle_events: sp_uuid, sp_chan, sp_start, sp_end, sp_dur, sp_stage, sp_method, sp_freq_lower, sp_freq_upper = sp try: # Add buffer buffer = event_opts['buffer'] start_with_buffer = max(0, sp_start - buffer) end_with_buffer = sp_end + buffer # Read data data = self.dataset.read_data(chan=[ch], begtime=start_with_buffer, endtime=end_with_buffer) # Create segment seg = { 'data': data, 'name': 'spindle', 'start': sp_start, 'end': sp_end, 'n_stitch': 0, 'stage': sp_stage, 'cycle': None, 'chan': ch, 'uuid': sp_uuid } segments.append(seg) except Exception as e: logger.error(f"Error creating segment for spindle {sp_uuid}: {e}") else: logger.error(f"Unknown event type: {event_type}") continue # Close database connection conn.close() if not segments or len(segments) == 0: logger.warning(f"No valid segments created from database events for {ch}") continue logger.info(f"Created {len(segments)} segments for PAC analysis") except Exception as e: logger.error(f"Error accessing database: {e}") import traceback traceback.print_exc() continue else: # Use standard fetch for continuous data # NEED TO FIX STAGE ISN NREM2NREM3 <=============================== segments = fetch(self.dataset, self.annotations, cat=cat, evt_type=None, stage=stage, cycle=cycle_idx, buffer=event_opts['buffer']) # Read data for the channel segments.read_data(ch, ref_chan, grp_name=grp_name) if not segments or len(segments) == 0: logger.warning(f"No valid data segments found for {ch}") continue logger.info(f"Processing {len(segments)} data segments") # 6. Define PAC object pac = Pac(idpac=idpac, f_pha=phase_freq, f_amp=amp_freq, dcomplex=filter_opts['dcomplex'], cycle=filter_opts['filtcycle'], width=filter_opts['width'], n_bins=nbins, verbose='ERROR') # 7. Process segments # Initialize arrays for results ampbin = np.zeros((len(segments), nbins)) ms = int(np.ceil(len(segments)/50)) longamp = np.zeros((ms, 50), dtype=object) # Blocked amplitude series longpha = np.zeros((ms, 50), dtype=object) # Blocked phase series for s, seg in enumerate(segments): # Print progress if progress: j = s/len(segments) sys.stdout.write('\r') sys.stdout.write(f"Progress: [{'»' * int(50 * j):{50}s}] {int(100 * j)}%") sys.stdout.flush() # Extract data data = seg['data'] timeline = data.axis['time'][0] # Fix polarity of recording if needed dat = data()[0][0] if invert: dat = dat * -1 # Obtain phase signal pha = np.squeeze(pac.filter(data.s_freq, dat, ftype='phase')) if len(pha.shape) > 2: pha = np.squeeze(pha) # Obtain amplitude signal amp = np.squeeze(pac.filter(data.s_freq, dat, ftype='amplitude')) if len(amp.shape) > 2: amp = np.squeeze(amp) # Extract signal (minus buffer) nbuff = int(event_opts['buffer'] * data.s_freq) minlen = data.s_freq * min_dur if len(pha) >= 2 * nbuff + minlen: pha = pha[nbuff:-nbuff] amp = amp[nbuff:-nbuff] # Put data in blocks (for surrogate testing) longpha[s//50, s%50] = pha longamp[s//50, s%50] = amp # Calculate mean amplitude per phase bin ampbin[s, :] = self._mean_amp(pha, amp, nbins=nbins) # Clear progress line sys.stdout.write('\r') sys.stdout.flush() # 8. If number of events not divisible by block length, # pad incomplete final block with randomly resampled events rem = len(segments) % 50 if rem > 0: pads = 50 - rem for pad in range(pads): ran = np.random.randint(0, rem) longpha[-1, rem+pad] = longpha[-1, ran] longamp[-1, rem+pad] = longamp[-1, ran] # 9. Calculate Coupling Strength mi = np.zeros((longamp.shape[0], 1)) mi_pv = np.zeros((longamp.shape[0], 1)) for row in range(longamp.shape[0]): pha_data = np.zeros((1)) amp_data = np.zeros((1)) for col in range(longamp.shape[1]): pha_data = np.concatenate((pha_data, longpha[row, col])) amp_data = np.concatenate((amp_data, longamp[row, col])) pha_data = np.reshape(pha_data, (1, 1, len(pha_data))) amp_data = np.reshape(amp_data, (1, 1, len(amp_data))) mi[row] = pac.fit(pha_data, amp_data, n_perm=400, random_state=5, verbose=False)[0][0] mi_pv[row] = pac.infer_pvalues(p=0.95, mcp='fdr')[0][0] # 10. Calculate preferred phase # Normalize amplitude by sum (to get probability distribution) ampbin = ampbin / ampbin.sum(-1, keepdims=True) ampbin = ampbin.squeeze() # Remove NaN trials ampbin = ampbin[~np.isnan(ampbin[:, 0]), :] ab = ampbin # Create bins for preferred phase vecbin = np.zeros(nbins) width = 2 * np.pi / nbins for n in range(nbins): vecbin[n] = n * width + width / 2 # Calculate circular statistics from scipy.stats import circmean, circvar # Find bin with max amplitude for each trial ab_pk = np.argmax(ab, axis=1) # Convert to angles angles = vecbin[ab_pk] # Calculate mean direction (theta) & mean vector length (rad) theta = circmean(angles) theta_deg = np.degrees(theta) if theta_deg < 0: theta_deg += 360 # Calculate circular variance (1 - R) circ_var = circvar(angles) rad = 1 - circ_var # Mean resultant length # Take mean across all segments/events ma = np.nanmean(ab, axis=0) # Correlation between mean amplitudes and phase-giving sine wave sine = np.sin(np.linspace(-np.pi, np.pi, nbins)) sine = np.interp(sine, (sine.min(), sine.max()), (ma.min(), ma.max())) from scipy.stats import pearsonr rho, pv1 = pearsonr(ma, sine) # # Rayleigh test for non-uniformity of circular data ppha = vecbin[ab.argmax(axis=-1)] # phase in radians n = len(ppha) r = np.abs(np.sum(np.exp(1j * ppha))) / n z = n * r**2 # Get test statistic from the rayleigh_test function pv2 = np.exp(-z) # Get p-value directly from the rayleigh_test function # 11. Export and save data # Save binned amplitudes to numpy file amp_file = outputfile.split('_pac_parameters.csv')[0] + '_mean_amps' np.save(amp_file, ab) # Save CFC metrics to dataframe d = pd.DataFrame([ np.mean(pac.pac), np.mean(mi), np.median(mi_pv), theta, theta_deg, rad, rho, z, pv2 ]).transpose() d.columns = [ 'mi_raw', 'mi_norm', 'median_mi_pval', 'preferred_phase_rad', 'preferred_phase_deg', 'mean_vector_length', 'rho', 'rayleigh_z', 'rayleigh_p' ] d.to_csv(outputfile, sep=',') logger.info(f"Saved PAC results to {outputfile}") logger.info(f"Saved mean amplitudes to {amp_file}.npy") # Store results in channel_results chan_results = { 'mi_raw': float(np.mean(pac.pac)), 'mi_norm': float(np.mean(mi)), 'pval': float(np.median(mi_pv)), 'preferred_phase_rad': float(theta), 'preferred_phase_deg': float(theta_deg), 'mean_vector_length': float(rad), 'rho': float(rho), 'rayleigh_z': float(z), 'rayleigh_p': float(pv2), 'n_segments': len(segments), 'outputfile': outputfile, 'amp_file': f"{amp_file}.npy" } except Exception as e: logger.error(f"Error processing channel {ch}: {e}") import traceback traceback.print_exc() flag += 1 continue # Add results to tracking if ch not in tracking['event_pac']: tracking['event_pac'][ch] = {} # Create a key based on parameters key = f"{phase_freq[0]}-{phase_freq[1]}Hz_{amp_freq[0]}-{amp_freq[1]}Hz" tracking['event_pac'][ch][key] = chan_results # Check completion status if flag == 0: logger.info("Phase-amplitude coupling analysis finished without errors") else: logger.warning(f"Phase-amplitude coupling analysis finished with {flag} warnings/errors") return tracking['event_pac']
def _mean_amp(self, pha, amp, nbins=18): """ Calculate mean amplitude in phase bins. Parameters ---------- pha : array Phase time series amp : array Amplitude time series nbins : int Number of phase bins Returns ------- array Mean amplitude in each phase bin """ # Convert phase to bin indices phase_bins = np.linspace(-np.pi, np.pi, nbins + 1) phase_bins_indices = np.digitize(pha, phase_bins) - 1 phase_bins_indices[phase_bins_indices == nbins] = 0 # Calculate mean amplitude in each bin mean_amp_bins = np.zeros(nbins) for i in range(nbins): bin_mask = phase_bins_indices == i if np.any(bin_mask): mean_amp_bins[i] = np.mean(amp[bin_mask]) return mean_amp_bins
[docs] def generate_comodulogram(self, chan=None, stage=None, phase_freqs=None, amp_freqs=None, idpac=(2, 3, 4), buffer=1.0, out_dir=None, reject_artf=['Artefact', 'Arousal']): """ Generate a comodulogram for the given channel and parameters. Parameters ---------- chan : str Channel to analyze stage : list or str Sleep stage(s) to analyze phase_freqs : list of tuples List of phase frequency bands to analyze amp_freqs : list of tuples List of amplitude frequency bands to analyze idpac : tuple PAC method settings (method, surrogate, correction) buffer : float Buffer in seconds out_dir : str Output directory for results reject_artf : list Event types to reject Returns ------- dict Dictionary containing comodulogram results """ from tensorpac import Pac logger = self.logger # NEED TO FIX STAGE ISN NREM2NREM3 <=============================== # Process stage input if isinstance(stage, str): parsed_stages = [] # Common stage names to look for known_stages = ["NREM1", "NREM2", "NREM3", "REM", "Wake"] for known_stage in known_stages: if known_stage in stage: parsed_stages.append(known_stage) if parsed_stages: logger.info(f"Parsed stage string '{stage}' into: {parsed_stages}") stage = parsed_stages else: # If no known stages found, treat it as a single stage stage = [stage] logger.warning(f"Could not parse stage string '{stage}', treating as a single stage") # Set default phase and amplitude frequencies if not provided if phase_freqs is None: phase_freqs = [(0.5, 1.5), (1.5, 4), (4, 8), (8, 13)] if amp_freqs is None: amp_freqs = [(8, 13), (13, 30), (30, 45), (55, 95)] # Set up output directory if out_dir is None: out_dir = os.path.join(self.rootpath, "wonambi", "pac_results") os.makedirs(out_dir, exist_ok=True) # Fetch data segments try: logger.info(f"Fetching data segments for channel {chan}") # Fetch segments based on sleep stage segments = fetch(self.dataset, self.annotations, cat=(1, 1,1,0), evt_type=None, stage=stage, cycle=None, buffer=buffer, reject_artf=reject_artf) # Read data for the channel segments.read_data(chan) if not segments or len(segments) == 0: logger.warning(f"No valid data segments found for {chan}") return None logger.info(f"Processing {len(segments)} data segments") # Concatenate data from all segments all_data = [] for seg in segments: data = seg['data'] all_data.append(data()[0][0]) # Concatenate data if all_data: data_array = np.concatenate(all_data) # Calculate sampling frequency s_freq = segments[0]['data'].s_freq # Create PAC object pac = Pac(idpac=idpac, verbose='ERROR') # Prepare phase and amplitude frequency ranges p_freqs = np.array([list(pf) for pf in phase_freqs]) a_freqs = np.array([list(af) for af in amp_freqs]) # Calculate comodulogram logger.info("Calculating comodulogram...") comod = pac.filterfit(s_freq, data_array, p_freqs, a_freqs, n_perm=200, progress_bar=True, random_state=42) # Save results stagename = '-'.join(stage) output_file = f"{out_dir}/comodulogram_{chan}_{stagename}.npz" np.savez(output_file, comod=comod, p_freqs=p_freqs, a_freqs=a_freqs, idpac=idpac, chan=chan, stage=stage) logger.info(f"Saved comodulogram to {output_file}") # Create and save plot fig = Figure(figsize=(10, 8), dpi=100) ax = fig.add_subplot(111) # Create meshgrid for plotting p_centers = [(p[0] + p[1])/2 for p in phase_freqs] a_centers = [(a[0] + a[1])/2 for a in amp_freqs] # Plot comodulogram as heatmap im = ax.imshow(comod, cmap='viridis', aspect='auto', extent=[p_centers[0], p_centers[-1], a_centers[0], a_centers[-1]], origin='lower') # Add colorbar cbar = fig.colorbar(im, ax=ax) cbar.set_label('PAC Strength') # Add labels ax.set_xlabel('Phase Frequency (Hz)') ax.set_ylabel('Amplitude Frequency (Hz)') ax.set_title(f'PAC Comodulogram - {chan} - {stagename}') # Set y-axis to log scale for better visualization ax.set_yscale('log') # Add frequency band labels ax.set_xticks([p[0] for p in phase_freqs] + [phase_freqs[-1][1]]) ax.set_yticks([a[0] for a in amp_freqs] + [amp_freqs[-1][1]]) # Save figure fig_file = f"{out_dir}/comodulogram_{chan}_{stagename}.png" fig.savefig(fig_file, dpi=300, bbox_inches='tight') logger.info(f"Saved comodulogram plot to {fig_file}") return { 'comod': comod, 'p_freqs': p_freqs, 'a_freqs': a_freqs, 'output_file': output_file, 'fig_file': fig_file } else: logger.warning("No data segments to process") return None except Exception as e: logger.error(f"Error generating comodulogram: {e}") import traceback traceback.print_exc() return None
[docs] def compare_conditions(self, condition1, condition2, test_type='watson_williams', alpha=0.05, out_dir=None): """ Compare PAC between two conditions. Parameters ---------- condition1 : dict First condition with keys 'amp_file', 'stage', etc. condition2 : dict Second condition with keys 'amp_file', 'stage', etc. test_type : str Type of statistical test ('watson_williams' or 'permutation') alpha : float Significance level out_dir : str Output directory for results Returns ------- dict Dictionary containing comparison results """ logger = self.logger # Set up output directory if out_dir is None: out_dir = os.path.join(self.rootpath, "wonambi", "pac_results") os.makedirs(out_dir, exist_ok=True) # Load data from condition files try: # Load amplitude data amp1 = np.load(condition1['amp_file']) amp2 = np.load(condition2['amp_file']) # Get number of bins nbins = amp1.shape[1] # Create bins for preferred phase vecbin = np.zeros(nbins) width = 2 * np.pi / nbins for n in range(nbins): vecbin[n] = n * width + width / 2 # Find preferred phase for each trial ab_pk1 = np.argmax(amp1, axis=1) ab_pk2 = np.argmax(amp2, axis=1) # Convert to angles angles1 = vecbin[ab_pk1] angles2 = vecbin[ab_pk2] # Perform statistical test if test_type == 'watson_williams': from scipy.stats import circmean from pingouin import circ_r # Calculate mean direction for each condition theta1 = circmean(angles1) theta2 = circmean(angles2) # Calculate mean vector length for each condition r1 = circ_r(vecbin, np.histogram(ab_pk1, bins=nbins)[0], d=width) r2 = circ_r(vecbin, np.histogram(ab_pk2, bins=nbins)[0], d=width) # Perform Watson-Williams test try: from pingouin import circ_wwtest # Run Watson-Williams test F, p = circ_wwtest(angles1, angles2, np.ones(angles1.shape), np.ones(angles2.shape)) # Save results cond1_name = condition1.get('name', 'Condition1') cond2_name = condition2.get('name', 'Condition2') output_file = f"{out_dir}/pac_comparison_{cond1_name}_vs_{cond2_name}.csv" results_df = pd.DataFrame({ 'Condition1': [cond1_name], 'Condition2': [cond2_name], 'Condition1_PP_rad': [theta1], 'Condition1_PP_deg': [np.degrees(theta1)], 'Condition1_MVL': [r1], 'Condition1_n': [len(angles1)], 'Condition2_PP_rad': [theta2], 'Condition2_PP_deg': [np.degrees(theta2)], 'Condition2_MVL': [r2], 'Condition2_n': [len(angles2)], 'F': [F], 'p': [p], 'Significant': [p < alpha] }) results_df.to_csv(output_file, index=False) logger.info(f"Saved comparison results to {output_file}") # Create and save plot fig = Figure(figsize=(10, 8), dpi=100) ax = fig.add_subplot(111, polar=True) # Calculate mean amplitudes for each condition mean_amp1 = np.nanmean(amp1, axis=0) mean_amp1 = mean_amp1 / mean_amp1.sum() mean_amp2 = np.nanmean(amp2, axis=0) mean_amp2 = mean_amp2 / mean_amp2.sum() # Create angles for plotting angles = np.linspace(0, 2*np.pi, nbins, endpoint=False) # Plot data ax.bar(angles, mean_amp1, width=width, alpha=0.5, label=cond1_name) ax.bar(angles, mean_amp2, width=width, alpha=0.5, label=cond2_name) # Add preferred phase markers ax.plot([theta1, theta1], [0, np.max(mean_amp1)*1.2], 'r-', linewidth=2) ax.plot([theta2, theta2], [0, np.max(mean_amp2)*1.2], 'b-', linewidth=2) # Add labels and title ax.set_title(f'PAC Comparison\n{cond1_name} vs {cond2_name}\nF={F:.2f}, p={p:.4f}') ax.set_theta_zero_location('N') # 0 at the top ax.set_theta_direction(-1) # clockwise # Add legend ax.legend() # Save figure fig_file = f"{out_dir}/pac_comparison_{cond1_name}_vs_{cond2_name}.png" fig.savefig(fig_file, dpi=300, bbox_inches='tight') logger.info(f"Saved comparison plot to {fig_file}") return { 'condition1': cond1_name, 'condition2': cond2_name, 'theta1': theta1, 'theta2': theta2, 'r1': r1, 'r2': r2, 'F': F, 'p': p, 'significant': p < alpha, 'output_file': output_file, 'fig_file': fig_file } except Exception as e: logger.error(f"Error performing Watson-Williams test: {e}") import traceback traceback.print_exc() return None elif test_type == 'permutation': # Implement permutation test for PAC comparison logger.error("Permutation test not implemented yet") return None else: logger.error(f"Unknown test type: {test_type}") return None except Exception as e: logger.error(f"Error comparing conditions: {e}") import traceback traceback.print_exc() return None
[docs] def export_pac_parameters_to_csv(self, json_dir=None, csv_file=None, channels=None, stages=None, phase_freq=None, amp_freq=None, append=True, method_info=None, out_dir=None): """ Export PAC parameters from tracking to a CSV file. Parameters ---------- json_dir : str Directory containing JSON files or individual channel CSV files csv_file : str Output CSV file channels : list List of channels to include stages : list List of sleep stages to include phase_freq : tuple Phase frequency range amp_freq : tuple Amplitude frequency range append : bool If True, append to existing CSV file by channel rather than overwrite method_info : dict Dictionary containing method information (sw_method, spindle_method) out_dir : str Base output directory to use Returns ------- dict Dictionary containing export results """ logger = self.logger # First, determine the base directory base_dir = out_dir if out_dir else json_dir if base_dir is None: base_dir = os.path.join(self.rootpath, "wonambi", "pac_results") # Create method-specific directory path method_dir = base_dir if method_info: sw_method = method_info.get('sw_method', 'unknown') spindle_method = method_info.get('spindle_method', 'unknown') event_type = method_info.get('event_type', 'unknown') stage = method_info.get('stage', 'all') # Create stage string stage_str = ''.join(stage) if isinstance(stage, list) else str(stage) # Determine method directory if event_type == 'slow_wave' and method_info.get('pair_with_spindles', False): method_dir_name = f"{sw_method}_paired_{spindle_method}" else: method_dir_name = sw_method if event_type == 'slow_wave' else spindle_method # Create full method directory path method_dir = os.path.join(base_dir, method_dir_name, stage_str) # Ensure directory exists os.makedirs(method_dir, exist_ok=True) logger.info(f"Using method directory: {method_dir}") # Create frequency string for filename freq_str = "" if phase_freq and amp_freq: ph_str = f"{phase_freq[0]}-{phase_freq[1]}Hz" amp_str = f"{amp_freq[0]}-{amp_freq[1]}Hz" freq_str = f"{ph_str}_{amp_str}" # Determine output CSV file if csv_file is None: if freq_str: csv_file = os.path.join(method_dir, f"pac_summary_{phase_freq[0]}-{phase_freq[1]}Hz_{amp_freq[0]}-{amp_freq[1]}Hz.csv") else: csv_file = os.path.join(method_dir, "pac_summary.csv") logger.info(f"Output summary CSV file: {csv_file}") # First approach: Look for individual channel result files # For PAC data, we need to look for files with pattern: # E*_slowwave_spindle_coupling_pha-FREQ-fixed_amp-FREQ-fixed_pac_parameters.csv if method_info and method_info.get('pair_with_spindles', False): # For SW-Spindle coupling file_pattern = f"*_slowwave_spindle_coupling_pha-{phase_freq[0]}-{phase_freq[1]}Hz-fixed_amp-{amp_freq[0]}-{amp_freq[1]}Hz-fixed_pac_parameters.csv" else: # For other coupling types file_pattern = f"*_pha-{phase_freq[0]}-{phase_freq[1]}Hz-fixed_amp-{amp_freq[0]}-{amp_freq[1]}Hz-fixed_pac_parameters.csv" # Find all matching channel CSV files channel_files = [] try: import glob channel_files = glob.glob(os.path.join(method_dir, file_pattern)) logger.info(f"Found {len(channel_files)} individual channel PAC parameter files") except Exception as e: logger.error(f"Error finding channel files: {e}") # If we found individual channel files, use them to build the summary if channel_files: try: import pandas as pd # Store all channel data all_data = [] # Process each file for file in channel_files: try: # Extract channel name from filename filename = os.path.basename(file) channel = filename.split('_')[0] # Assuming format: E101_slowwave_... # Read channel data df = pd.read_csv(file) if not df.empty: # Add channel data to combined list for _, row in df.iterrows(): # Create data row data_row = { 'Channel': channel, 'Phase_Freq': f"{phase_freq[0]}-{phase_freq[1]}", 'Amp_Freq': f"{amp_freq[0]}-{amp_freq[1]}", } # Copy relevant metrics metric_cols = [ 'mi_raw', 'mi_norm', 'median_mi_pval', 'preferred_phase_rad', 'preferred_phase_deg', 'mean_vector_length', 'rho', 'rayleigh_z', 'rayleigh_p' ] for col in metric_cols: if col in row: data_row[col] = row[col] all_data.append(data_row) logger.info(f"Processed data from {file}") except Exception as e: logger.error(f"Error processing {file}: {e}") # Create summary dataframe if all_data: summary_df = pd.DataFrame(all_data) # Check if we should append to existing file if append and os.path.exists(csv_file): # Read existing data try: existing_df = pd.read_csv(csv_file) # Create set of existing channels existing_channels = set() if 'Channel' in existing_df.columns: for _, row in existing_df.iterrows(): ch = row['Channel'] ph_freq = row['Phase_Freq'] if 'Phase_Freq' in row else "" amp_freq = row['Amp_Freq'] if 'Amp_Freq' in row else "" existing_channels.add(f"{ch}_{ph_freq}_{amp_freq}") # Filter out channels that already exist new_data = [] for row in all_data: ch = row['Channel'] ph_freq = row['Phase_Freq'] amp_freq = row['Amp_Freq'] key = f"{ch}_{ph_freq}_{amp_freq}" if key not in existing_channels: new_data.append(row) # Append new data to existing data if new_data: new_df = pd.DataFrame(new_data) summary_df = pd.concat([existing_df, new_df]) logger.info(f"Appending {len(new_data)} new channels to existing summary") else: summary_df = existing_df logger.info("No new data to append") except Exception as e: logger.error(f"Error appending to existing file: {e}, creating new file") # Write summary to CSV summary_df.to_csv(csv_file, index=False) logger.info(f"Exported PAC summary to {csv_file} with {len(summary_df)} entries") return { 'file': csv_file, 'channels': len(summary_df['Channel'].unique()), 'rows': len(summary_df) } else: logger.warning("No PAC data to export") return None except Exception as e: logger.error(f"Error creating summary from files: {e}") import traceback traceback.print_exc() # Second approach: Use tracking data if available and no files were found elif 'event_pac' in self.tracking and self.tracking['event_pac']: try: # Filter channels if specified if channels is None: channels = list(self.tracking['event_pac'].keys()) else: channels = [ch for ch in channels if ch in self.tracking['event_pac']] # Create key based on frequency bands key = None if phase_freq and amp_freq: key = f"{phase_freq[0]}-{phase_freq[1]}Hz_{amp_freq[0]}-{amp_freq[1]}Hz" # Read existing data if appending existing_data = {} if append and os.path.exists(csv_file): try: import pandas as pd # Read existing CSV into DataFrame existing_df = pd.read_csv(csv_file) logger.info(f"Read {len(existing_df)} existing entries from {csv_file}") # Convert DataFrame to dictionary keyed by channel for _, row in existing_df.iterrows(): ch = row['Channel'] if ch not in existing_data: existing_data[ch] = {} # Create frequency key from Phase_Freq and Amp_Freq ph_freq = row['Phase_Freq'] if 'Phase_Freq' in row else "" amp_freq = row['Amp_Freq'] if 'Amp_Freq' in row else "" freq_key = f"{ph_freq}_{amp_freq}" # Store row data existing_data[ch][freq_key] = row.to_dict() except Exception as e: logger.warning(f"Could not read existing CSV for appending: {e}") existing_data = {} # Prepare data for export data = [] for ch in channels: if ch not in self.tracking['event_pac']: continue ch_results = self.tracking['event_pac'][ch] if key and key in ch_results: # Use specific frequency key results = ch_results[key] # Check if already in existing data skip_channel = False if append and ch in existing_data: for ex_key, ex_data in existing_data[ch].items(): # See if there's a matching frequency entry if ex_key.startswith(f"{phase_freq[0]}-{phase_freq[1]}") and \ ex_key.endswith(f"{amp_freq[0]}-{amp_freq[1]}"): # Check if existing has more segments if ex_data.get('N_Segments', 0) > results.get('n_segments', 0): logger.info(f"Skipping {ch}/{key}: existing has more segments") data.append(ex_data) skip_channel = True break if not skip_channel: data.append({ 'Channel': ch, 'Phase_Freq': f"{phase_freq[0]}-{phase_freq[1]}", 'Amp_Freq': f"{amp_freq[0]}-{amp_freq[1]}", 'MI': results.get('mi_norm', float('nan')), 'MI_pval': results.get('pval', float('nan')), 'PP_rad': results.get('preferred_phase_rad', float('nan')), 'PP_degrees': results.get('preferred_phase_deg', float('nan')), 'Mean_vector_length': results.get('mean_vector_length', float('nan')), 'rho': results.get('rho', float('nan')), 'Rayleigh_z': results.get('rayleigh_z', float('nan')), 'Rayleigh_p': results.get('rayleigh_p', float('nan')), 'N_Segments': results.get('n_segments', 0) }) else: # Export all frequency combinations for freq_key, results in ch_results.items(): try: # Parse frequency ranges from key freq_parts = freq_key.split('_') ph_freq = freq_parts[0] amp_freq = freq_parts[1] # Check if already in existing data skip_entry = False if append and ch in existing_data: for ex_key, ex_data in existing_data[ch].items(): if ex_key == freq_key: # Check if existing has more segments if ex_data.get('N_Segments', 0) > results.get('n_segments', 0): logger.info(f"Skipping {ch}/{freq_key}: existing has more segments") data.append(ex_data) skip_entry = True break if not skip_entry: data.append({ 'Channel': ch, 'Phase_Freq': ph_freq, 'Amp_Freq': amp_freq, 'MI': results.get('mi_norm', float('nan')), 'MI_pval': results.get('pval', float('nan')), 'PP_rad': results.get('preferred_phase_rad', float('nan')), 'PP_degrees': results.get('preferred_phase_deg', float('nan')), 'Mean_vector_length': results.get('mean_vector_length', float('nan')), 'rho': results.get('rho', float('nan')), 'Rayleigh_z': results.get('rayleigh_z', float('nan')), 'Rayleigh_p': results.get('rayleigh_p', float('nan')), 'N_Segments': results.get('n_segments', 0) }) except Exception as e: logger.warning(f"Could not parse frequency key: {freq_key} - {e}") # Create DataFrame and export to CSV if data: import pandas as pd df = pd.DataFrame(data) # If append and file exists, merge with existing data if append and os.path.exists(csv_file): try: existing_df = pd.read_csv(csv_file) # Only keep rows from existing_df that aren't already in our new data combined_df = pd.concat([existing_df, df]).drop_duplicates( subset=['Channel', 'Phase_Freq', 'Amp_Freq'], keep='last' ) combined_df.to_csv(csv_file, index=False) logger.info(f"Appended to existing CSV: {len(df)} new rows, {len(combined_df)} total rows") except Exception as e: logger.error(f"Error appending to existing CSV: {e}") df.to_csv(csv_file, index=False) logger.info(f"Created new CSV with {len(df)} rows") else: df.to_csv(csv_file, index=False) logger.info(f"Created new CSV with {len(df)} rows") return {'file': csv_file, 'channels': len(channels), 'rows': len(data)} else: logger.warning("No PAC data to export") return None except Exception as e: logger.error(f"Error exporting PAC parameters from tracking: {e}") import traceback traceback.print_exc() return None else: logger.warning("No PAC results in tracking dictionary or individual files") return None