ihkapy.sm_make_wavelet_bank
Author statement
This file is based on sm_getPowerPerChannel.m, written by Samuel McKenzie, and on awt_freqlist.m, written by Maureen Clerc and Christian Benar (October 2007). Translated and adapted to Python in May 2022 by Stephen Fay (dcxstephen@gmail.com).
1""" 2Author statement 3---------------- 4This file is based off of sm_getPowerPerChannel.m written by Samuel McKenzie, 5and awt_freqlist.m and Maureen Clerc, Christian Benar, october 2007 6Translated and adapted to python in May 2022 by Stephen Fay dcxstephen@gmail.com 7""" 8 9# TODO: standardize name 'channel' to 'raw_ch_idx' 10# TODO: standardize naming convention 'amp' and 'power' are used as synonyms 11# I think 'amp' is better because it looks less like 'phase', easy distinguish 12 13from ihkapy.fileio.binary_io import merge_dats # local dependency 14from ihkapy.fileio import utils 15from ihkapy.fileio.options_io import load_fio_ops_and_data_ops 16import os # I/O 17import shutil # I/O 18from tqdm import tqdm # Progressbar 19import logging # For debugging and following code 20import warnings # Bulletproof code 21import re # Regexp library, to bulletproof code 22import pyedflib # Read from edf files | this is "terrible in MatLab", look into it, apparently it loads everything into RAM 23import numpy as np # Array manipulation, Scientific computing 24from numpy.fft import fft, ifft # Signal processing 25from scipy.stats import zscore # Signal processing 26 27 28# Init logger and set the logging level 29logging.basicConfig(level=logging.DEBUG) 30logger = logging.getLogger(__name__) 31 32# For variables containing strings of with absolute path, we explicitly 33# include the word "path" in the variable name. For those with relative 34# or leaf paths, we do not put "path" in the name. 
# TODO: implement for Lusin and Sombrero wavelets too
# Note, Lusin wasn't implemented in Matlab, and pipeline only used Gabor
#
# compute_wavelet_gabor corresponds to awt_freqlist in the MatLab code
#
# % Maureen Clerc, Christian Benar, october 2007
# % modified from awt from wavelab toolbox
#
# % History of changes
# % 1/11/2007: (cgb) psi_array: output in complex
# % 3/06/2008: (cgb) init of psi_array to size of wt
# 3/05/2022: SF translated awt_freqlist to python

def compute_wavelet_gabor(
        signal: np.ndarray,
        fs: "int | float",
        freqs: "list | float",
        xi: int = 5  # only needed for Gabor
        ) -> np.ndarray:
    """Computes one or multiple Gabor wavelet transforms of the input signal.

    Follows awt_freqlist.m from the buzzcode repository. The transform is
    evaluated in the frequency domain: the FFT of the signal is multiplied
    by a Gaussian window for each requested frequency, then inverse
    transformed.

    Parameters
    ----------
    `signal : np.ndarray`
        The input signal. Only accepts 1D signals.

    `fs : int or float`
        The sampling frequency. Must be strictly positive.

    `freqs : list or float`
        The frequency or list of frequencies to compute.

    `xi : int`
        The number of oscillations parameter, only needed for Gabor wavelet.

    Returns
    -------
    `np.ndarray`
        A complex64 numpy array of dim (len(signal),len(freqs)), squeezed
        to dim (len(signal),) when a single frequency is requested.

    Warns
    -----
    `UserWarning`
        If any requested frequency maps to a scale outside of
        [minscale, maxscale]. The transform is still computed for it.
    """
    # Make sure all types are correct. atleast_1d promotes python and numpy
    # scalars alike to a one-element array.
    freqs = np.atleast_1d(np.asarray(freqs))
    signal = np.asarray(signal)
    assert isinstance(fs, (int, float, np.integer, np.floating)) and fs > 0, \
        "fs must be a strictly positive real number"
    assert signal.ndim == 1, "Must be single dim signal"
    # TODO: implement multi-dim and remove above assertion
    # (not crucial because we don't (yet) use that in pipeline)

    (len_sig,) = signal.shape
    sigma2 = 1
    # Frequency (Hz) of each FFT bin: non-negative bins first, then the
    # negative ones
    omega = np.concatenate((np.arange(0,len_sig//2+1) , np.arange(-((len_sig+1)//2)+1,0))) * fs / len_sig

    # Warning: this code was dogmatically translated from MatLab repo
    tolerance = 0.5
    mincenterfreq = 2*tolerance*np.sqrt(sigma2)*fs*xi / len_sig
    maxcenterfreq = fs*xi/(xi+tolerance/np.sqrt(sigma2)) # Shouldn't this be divided by two because of aliasing?
    nyquist = fs / 2
    maxcenterfreq = min(maxcenterfreq,nyquist)
    logger.debug(f"fs = {fs}")
    logger.debug(f"freqs = {freqs}")
    if freqs.size:  # min/max raise on an empty list of frequencies
        logger.debug(f"\n\tLowest freq = {min(freqs)}\n\tHighest freq = {max(freqs)}")
    logger.debug(f"\n\tmincenterfreq = {mincenterfreq}\n\tmaxcenterfreq = {maxcenterfreq}")

    s_arr = xi / freqs
    minscale = xi / maxcenterfreq
    maxscale = xi / mincenterfreq
    # Flag frequencies whose scale falls outside [minscale, maxscale].
    # They are only reported, the transform is still computed for them.
    out_of_range = (s_arr < minscale) | (s_arr > maxscale)
    if out_of_range.any():
        warnings.warn("Frequencies are not between minscale and maxscale.")

    n_freqs = len(freqs)
    # np.complex64 is numpy's coarsest complex numpy type
    wt = np.zeros((len_sig,n_freqs),dtype=np.complex64)

    # Loop invariants: the signal spectrum and the normalisation constant
    signal_fft = fft(signal)
    norm = np.power(4*np.pi*sigma2,0.25)
    for idx,s in enumerate(s_arr):
        freq = (s * omega - xi)
        psi = norm * np.sqrt(s) * np.exp(-sigma2/2 * freq*freq)
        wt[:,idx] = ifft(signal_fft * psi)

    return np.squeeze(wt) # turns 2d into 1d IFF single freq


# Helper, test to make sure our cache folder is not corrupted
def _assert_all_ext_type_match_regexp(
        directory: str,
        extension: str,
        regexp_base: str):
    """Asserts that the files of one type in a directory are well named.

    Every file in `directory` whose extension equals `extension` must have
    a base name (file name without extension) matched by `regexp_base`.

    Parameters
    ----------
    `directory : str`
        The directory whose content is checked.

    `extension : str`
        The file extension to check, with or without the leading dot.

    `regexp_base : str`
        Regular expression searched for in the base name of each file.

    Raises
    ------
    `AssertionError`
        If a file with the given extension does not match the regexp.
    """
    # os.path.splitext returns the extension with its leading dot, so
    # "dat" must be compared as ".dat", otherwise no file is ever checked
    if extension and not extension.startswith("."):
        extension = "." + extension
    pattern = re.compile(regexp_base)
    for fname in os.listdir(directory):
        base,ext = os.path.splitext(fname)
        if ext == extension:
            assert pattern.search(base) is not None, \
                f"'{fname}' in {directory} does not match the regexp:\n{regexp_base}"
    logger.debug(f"Test passed: all '{extension}' files in {directory} match the regexp:\n{regexp_base}")
    return


def make_wavelet_bank(
        edf_fname:str,
        fio_ops:dict,
        data_ops:dict):
    """Computes and saves a wavelet decomposition of each channel.

    Uses dictionaries loaded from user defined options from Options.toml
    (options_filepath) file to compute the Gabor wavelet decomposition
    of the raw signals in the provided edf file (edf_fname).
    This function doesn't return anything, but reads and writes to disk.

    The signals are scaled before saving to hard disk. This is to mitigate
    quantization effects, since we are saving our data as int16.

    - Reads edf raw signal specified by edf_fname (and fio_ops params)
    - Iterates through each channel, computing wavelet convolutions
      for frequencies in a range specified by data_ops
    - Saves output binaries, one binary file for each hardware channel,
      all the frequencies are saved according to the below order

    Binaries array flattening convention:
    - Read 'sn' as 'sample number n'
    - A is for Amplitude (=Power), and P is for Phase
    - K is the index of the last frequency (= num of freqs - 1)
    [raw_s0,freq00_A_s0,freq00_P_s0,freq01_A_s0,freq01_P_s0,...,freqK_A_s0,
    freqK_P_s0,raw_s1,freq00_A_s1,freq00_P_s1,...,freqK_A_s1,freqK_P_s1,...
    ...
    raw_sn,freq00_A_sn,freq00_P_sn,freq01_A_sn,...,freqK_P_sn]

    Note: it is important the above convention is respected because this is
    how the binary_io tools read the files. It's the same convention as the
    MatLab suite.

    Parameters
    ----------

    `edf_fname`
        The name of the '.edf' raw data file, looked up inside of
        fio_ops["RAW_DATA_PATH"]. The base name must not contain dots.

    `fio_ops : dict`
        The fio parameters defined in the Options.toml config file.

    `data_ops : dict`
        Data parameters from the Options.toml config file.
        The Options.toml config file contains user defined parameters,
        including fio params, data and data-processing params, ML
        model hyper-params, and training params.

    Raises
    ------
    `AssertionError`
        If the file name is malformed, a path does not exist, the number of
        channels does not equal N_CHAN_RAW, or the cache content is
        unexpected.
    """
    assert len(edf_fname.split("."))==2, f"There can be no dots in the base file name, {edf_fname}"
    basename,ext = edf_fname.split(".")
    assert ext == "edf", f"Incorrect file format, expected .edf, got {edf_fname}"

    # Unpack File IO constants
    RAW_DATA_PATH = fio_ops["RAW_DATA_PATH"]
    WAVELET_BINARIES_PATH = fio_ops["WAVELET_BINARIES_PATH"]
    # Unpack Data config constants
    FS = data_ops["FS"]
    NUM_FREQ = data_ops["NUM_FREQ"]
    LOW_FREQ = data_ops["LOW_FREQ"]
    HIGH_FREQ = data_ops["HIGH_FREQ"]
    # NOTE(review): SPACING is read but not used below, FREQS is always
    # log-spaced
    SPACING = data_ops["SPACING"]
    ZSCORE_POWER = data_ops["ZSCORE_POWER"] # bool
    SCALE_RAW = data_ops["SCALE_RAW"]
    SCALE_PHASE = data_ops["SCALE_PHASE"]
    SCALE_POWER = data_ops["SCALE_POWER"]
    N_CHAN_RAW = data_ops["N_CHAN_RAW"]
    TS_FEATURES = data_ops["TS_FEATURES"]
    # Define features space (based on num, lo, high-freq)
    FREQS = np.logspace(np.log10(LOW_FREQ), np.log10(HIGH_FREQ), NUM_FREQ)

    ### IO Checks
    # Check edf file and wavelet binaries directory exist
    edf_path = os.path.join(RAW_DATA_PATH , edf_fname)
    assert os.path.exists(edf_path), f"Invalid edf file path. Make sure edf file exists and is inside of the {RAW_DATA_PATH} directory."
    assert os.path.exists(WAVELET_BINARIES_PATH), f"Invalid path {WAVELET_BINARIES_PATH}\nCheck your configuration file Options.toml"
    # Check if the cache folder for binaries exists, delete it
    cache_dir_path = os.path.join(WAVELET_BINARIES_PATH, "cache")
    if os.path.exists(cache_dir_path):
        logger.info("Deleting cache")
        shutil.rmtree(cache_dir_path)
    else: logger.debug("No cache, proceeding")

    # Read edf file and loop through each channel one at a time
    for channel in range(N_CHAN_RAW):
        # TODO: simplify and sacrifice some of the checks in this untidy loop
        # in the name of tidyness and readability and the good lord
        ch_tag = str(channel).zfill(3)
        os.mkdir(cache_dir_path) # Create the cache directory
        with pyedflib.EdfReader(edf_path) as f:
            assert N_CHAN_RAW == f.signals_in_file, f"N_CHAN_RAW={N_CHAN_RAW} incompatible with detected number of channels in file ={f.signals_in_file}"
            sig = f.readSignal(channel) # .astype("int16") # this would quantize signal badly
            assert sig.shape==(f.getNSamples()[0],) # make sure exists and right shape
        _rms = np.sqrt(np.power(sig,2).mean()) # for debugging only
        logger.debug(f"Sample of quantized sig {sig[:10].astype('int16')}")
        logger.debug(f"Root Mean Square raw signal = {_rms:.3f}")

        # Save raw channel data as .dat binary
        logger.debug(f"cache_dir_path = '{cache_dir_path}'")
        cached_binary_raw_path = utils.fmt_binary_cache_wav_path(
                cache_dir_path,
                basename,
                channel,
                "RAW")
        (sig * SCALE_RAW).astype("int16").tofile(cached_binary_raw_path)

        print(f"Computing {NUM_FREQ} Gabor wavelet convolutions for channel {channel}.")
        # Loop through each frequency, convolve with Gabor wavelet
        # (total is given explicitly because enumerate has no length)
        for f_idx,freq in tqdm(enumerate(FREQS), total=len(FREQS)):
            # Define cache filepath for this channel & freq
            cached_binary_amp_path,cached_binary_phase_path = utils.fmt_binary_cache_wav_path(
                    cache_dir_path,
                    basename,
                    channel,
                    "AMP-PHASE",
                    f_idx)

            # Convolve signal with the wavelet, similar to awt_freqs
            wt = compute_wavelet_gabor(signal=sig,fs=FS,freqs=[freq])

            # TODO: do we need this condition? Won't we always use wavelet power?
            # if we don't use wavelet power it might break some of the code
            # Conditionally Zscore, re-scale, and save the power
            if "wavelet_power" in TS_FEATURES:
                wt_power = np.abs(wt)
                if ZSCORE_POWER:
                    wt_power = zscore(wt_power)
                # Comment in Options.toml: SCALE_POWER should be smaller if no zscore
                logger.debug(f"wt_power.dtype={wt_power.dtype}")
                wt_power = (wt_power * SCALE_POWER).astype("int16")
                wt_power.tofile(cached_binary_amp_path)

            # Conditionally re-scale and save the phase
            if "wavelet_phase" in TS_FEATURES:
                # NOTE(review): arctan(real/imag) is not the usual complex
                # argument (np.angle(wt) == arctan2(imag, real)); it only
                # spans (-pi/2, pi/2) and divides by zero when imag == 0.
                # Kept as is because it defines the data already on disk,
                # confirm against the MatLab pipeline before changing.
                wt_phase = (np.arctan(np.real(wt) / np.imag(wt)) * SCALE_PHASE).astype("int16")
                wt_phase.tofile(cached_binary_phase_path)
        logger.info(f"Finished computing wavelet transforms for channel={channel}.")

        # Check all the dat files in the cache match with the regex,
        # this is a test to make sure our cached folder is not corrupted
        # Filenames must not contain any dots!
        # Regexp must match all the Amplitude (=Power) and Phase transforms
        # as well as the raw .dat binary.
        # Raw string so that \d reaches the regexp engine untouched, the
        # base name is escaped so that it is matched literally.
        regex = rf"^{re.escape(basename)}_ch_{ch_tag}_(freqidx_\d\d|0RAW)(_A|_P|)$"
        _assert_all_ext_type_match_regexp(
                directory = cache_dir_path,
                extension = ".dat",
                regexp_base = regex)

        ### Merge all the cached frequency .dat files into a single one.
        # Sort the cached binaries, this will put lower f_idx first and
        # alternate A before P, exact same order they are created
        # All of the files in the cache have the same channel number (checked above)
        sorted_cache_binaries = sorted(
                i for i in os.listdir(cache_dir_path) if i.endswith(".dat"))
        # Sanity check assert
        assert "RAW" in sorted_cache_binaries[0], "Error in sorting the binaries, first file must be the raw channel"
        scb_paths = [os.path.join(cache_dir_path,i) for i in sorted_cache_binaries]
        merge_dats(
                fpaths_in = scb_paths,
                dir_out = WAVELET_BINARIES_PATH,
                fname_out = f"{basename}_ch_{ch_tag}.dat")

        # Delete the cached single-channel binary files
        shutil.rmtree(cache_dir_path)
    return


def make_wavelet_bank_all(options_path="Options.toml"):
    """Runs make_wavelet_bank on every '.edf' file of the raw data directory.

    Parameters
    ----------
    `options_path : str`
        Path of the config file the fio and data options are loaded from.
        The '.edf' files are listed from its fio option RAW_DATA_PATH.
    """
    # Unpack parameters and user-defined constants
    fio_ops,data_ops = load_fio_ops_and_data_ops(options_path)
    # Bind the consts we use to local vars for readability
    RAW_DATA_PATH = fio_ops["RAW_DATA_PATH"]

    edf_files = [i for i in os.listdir(RAW_DATA_PATH) if os.path.splitext(i)[1]==".edf"]
    print(f"Make wavelet bank all, convolving {len(edf_files)} edf files.")
    for edf_fname in edf_files:
        logger.info(f"Running make_wavelet_bank on {edf_fname}")
        # Convolve with wavelets and write binary files
        make_wavelet_bank(edf_fname, fio_ops, data_ops)

    return


# Next block runs only when you run this file directly, not on import
if __name__ == "__main__":
    # make_wavelet_bank_all(options_path="Options.toml")
    make_wavelet_bank_all(options_path="Options_test.toml")
def compute_wavelet_gabor(
        signal: np.ndarray,
        fs: "int | float",
        freqs: "list | float",
        xi: int = 5  # only needed for Gabor
        ) -> np.ndarray:
    """Computes one or multiple Gabor wavelet transforms of the input signal.

    Follows awt_freqlist.m from the buzzcode repository. The transform is
    evaluated in the frequency domain: the FFT of the signal is multiplied
    by a Gaussian window for each requested frequency, then inverse
    transformed.

    Parameters
    ----------
    `signal : np.ndarray`
        The input signal. Only accepts 1D signals.

    `fs : int or float`
        The sampling frequency. Must be strictly positive.

    `freqs : list or float`
        The frequency or list of frequencies to compute.

    `xi : int`
        The number of oscillations parameter, only needed for Gabor wavelet.

    Returns
    -------
    `np.ndarray`
        A complex64 numpy array of dim (len(signal),len(freqs)), squeezed
        to dim (len(signal),) when a single frequency is requested.

    Warns
    -----
    `UserWarning`
        If any requested frequency maps to a scale outside of
        [minscale, maxscale]. The transform is still computed for it.
    """
    # Make sure all types are correct. atleast_1d promotes python and numpy
    # scalars alike to a one-element array.
    freqs = np.atleast_1d(np.asarray(freqs))
    signal = np.asarray(signal)
    assert isinstance(fs, (int, float, np.integer, np.floating)) and fs > 0, \
        "fs must be a strictly positive real number"
    assert signal.ndim == 1, "Must be single dim signal"
    # TODO: implement multi-dim and remove above assertion
    # (not crucial because we don't (yet) use that in pipeline)

    (len_sig,) = signal.shape
    sigma2 = 1
    # Frequency (Hz) of each FFT bin: non-negative bins first, then the
    # negative ones
    omega = np.concatenate((np.arange(0,len_sig//2+1) , np.arange(-((len_sig+1)//2)+1,0))) * fs / len_sig

    # Warning: this code was dogmatically translated from MatLab repo
    tolerance = 0.5
    mincenterfreq = 2*tolerance*np.sqrt(sigma2)*fs*xi / len_sig
    maxcenterfreq = fs*xi/(xi+tolerance/np.sqrt(sigma2)) # Shouldn't this be divided by two because of aliasing?
    nyquist = fs / 2
    maxcenterfreq = min(maxcenterfreq,nyquist)
    logger.debug(f"fs = {fs}")
    logger.debug(f"freqs = {freqs}")
    if freqs.size:  # min/max raise on an empty list of frequencies
        logger.debug(f"\n\tLowest freq = {min(freqs)}\n\tHighest freq = {max(freqs)}")
    logger.debug(f"\n\tmincenterfreq = {mincenterfreq}\n\tmaxcenterfreq = {maxcenterfreq}")

    s_arr = xi / freqs
    minscale = xi / maxcenterfreq
    maxscale = xi / mincenterfreq
    # Flag frequencies whose scale falls outside [minscale, maxscale].
    # They are only reported, the transform is still computed for them.
    out_of_range = (s_arr < minscale) | (s_arr > maxscale)
    if out_of_range.any():
        warnings.warn("Frequencies are not between minscale and maxscale.")

    n_freqs = len(freqs)
    # np.complex64 is numpy's coarsest complex numpy type
    wt = np.zeros((len_sig,n_freqs),dtype=np.complex64)

    # Loop invariants: the signal spectrum and the normalisation constant
    signal_fft = fft(signal)
    norm = np.power(4*np.pi*sigma2,0.25)
    for idx,s in enumerate(s_arr):
        freq = (s * omega - xi)
        psi = norm * np.sqrt(s) * np.exp(-sigma2/2 * freq*freq)
        wt[:,idx] = ifft(signal_fft * psi)

    return np.squeeze(wt) # turns 2d into 1d IFF single freq
Computes one or multiple wavelet transforms of the input signal.
Follows awt_freqlist.m from the buzzcode repository.
Parameters
signal : np.ndarray
The input signal. Only accepts 1D signals.
fs : int or float
The sampling frequency.
freqs : list or float
The frequency or list of frequencies to compute.
xi : int
The number of oscillations parameter, only needed for Gabor wavelet.
Returns
np.ndarray
A complex numpy array of dim (len(signal),len(freqs)), squeezed to dim (len(signal),) when a single frequency is given
def make_wavelet_bank(
        edf_fname:str,
        fio_ops:dict,
        data_ops:dict):
    """Computes and saves a wavelet decomposition of each channel.

    Uses dictionaries loaded from user defined options from Options.toml
    (options_filepath) file to compute the Gabor wavelet decomposition
    of the raw signals in the provided edf file (edf_fname).
    This function doesn't return anything, but reads and writes to disk.

    The signals are scaled before saving to hard disk. This is to mitigate
    quantization effects, since we are saving our data as int16.

    - Reads edf raw signal specified by edf_fname (and fio_ops params)
    - Iterates through each channel, computing wavelet convolutions
      for frequencies in a range specified by data_ops
    - Saves output binaries, one binary file for each hardware channel,
      all the frequencies are saved according to the below order

    Binaries array flattening convention:
    - Read 'sn' as 'sample number n'
    - A is for Amplitude (=Power), and P is for Phase
    - K is the index of the last frequency (= num of freqs - 1)
    [raw_s0,freq00_A_s0,freq00_P_s0,freq01_A_s0,freq01_P_s0,...,freqK_A_s0,
    freqK_P_s0,raw_s1,freq00_A_s1,freq00_P_s1,...,freqK_A_s1,freqK_P_s1,...
    ...
    raw_sn,freq00_A_sn,freq00_P_sn,freq01_A_sn,...,freqK_P_sn]

    Note: it is important the above convention is respected because this is
    how the binary_io tools read the files. It's the same convention as the
    MatLab suite.

    Parameters
    ----------

    `edf_fname`
        The name of the '.edf' raw data file, looked up inside of
        fio_ops["RAW_DATA_PATH"]. The base name must not contain dots.

    `fio_ops : dict`
        The fio parameters defined in the Options.toml config file.

    `data_ops : dict`
        Data parameters from the Options.toml config file.
        The Options.toml config file contains user defined parameters,
        including fio params, data and data-processing params, ML
        model hyper-params, and training params.

    Raises
    ------
    `AssertionError`
        If the file name is malformed, a path does not exist, the number of
        channels does not equal N_CHAN_RAW, or the cache content is
        unexpected.
    """
    assert len(edf_fname.split("."))==2, f"There can be no dots in the base file name, {edf_fname}"
    basename,ext = edf_fname.split(".")
    assert ext == "edf", f"Incorrect file format, expected .edf, got {edf_fname}"

    # Unpack File IO constants
    RAW_DATA_PATH = fio_ops["RAW_DATA_PATH"]
    WAVELET_BINARIES_PATH = fio_ops["WAVELET_BINARIES_PATH"]
    # Unpack Data config constants
    FS = data_ops["FS"]
    NUM_FREQ = data_ops["NUM_FREQ"]
    LOW_FREQ = data_ops["LOW_FREQ"]
    HIGH_FREQ = data_ops["HIGH_FREQ"]
    # NOTE(review): SPACING is read but not used below, FREQS is always
    # log-spaced
    SPACING = data_ops["SPACING"]
    ZSCORE_POWER = data_ops["ZSCORE_POWER"] # bool
    SCALE_RAW = data_ops["SCALE_RAW"]
    SCALE_PHASE = data_ops["SCALE_PHASE"]
    SCALE_POWER = data_ops["SCALE_POWER"]
    N_CHAN_RAW = data_ops["N_CHAN_RAW"]
    TS_FEATURES = data_ops["TS_FEATURES"]
    # Define features space (based on num, lo, high-freq)
    FREQS = np.logspace(np.log10(LOW_FREQ), np.log10(HIGH_FREQ), NUM_FREQ)

    ### IO Checks
    # Check edf file and wavelet binaries directory exist
    edf_path = os.path.join(RAW_DATA_PATH , edf_fname)
    assert os.path.exists(edf_path), f"Invalid edf file path. Make sure edf file exists and is inside of the {RAW_DATA_PATH} directory."
    assert os.path.exists(WAVELET_BINARIES_PATH), f"Invalid path {WAVELET_BINARIES_PATH}\nCheck your configuration file Options.toml"
    # Check if the cache folder for binaries exists, delete it
    cache_dir_path = os.path.join(WAVELET_BINARIES_PATH, "cache")
    if os.path.exists(cache_dir_path):
        logger.info("Deleting cache")
        shutil.rmtree(cache_dir_path)
    else: logger.debug("No cache, proceeding")

    # Read edf file and loop through each channel one at a time
    for channel in range(N_CHAN_RAW):
        # TODO: simplify and sacrifice some of the checks in this untidy loop
        # in the name of tidyness and readability and the good lord
        ch_tag = str(channel).zfill(3)
        os.mkdir(cache_dir_path) # Create the cache directory
        with pyedflib.EdfReader(edf_path) as f:
            assert N_CHAN_RAW == f.signals_in_file, f"N_CHAN_RAW={N_CHAN_RAW} incompatible with detected number of channels in file ={f.signals_in_file}"
            sig = f.readSignal(channel) # .astype("int16") # this would quantize signal badly
            assert sig.shape==(f.getNSamples()[0],) # make sure exists and right shape
        _rms = np.sqrt(np.power(sig,2).mean()) # for debugging only
        logger.debug(f"Sample of quantized sig {sig[:10].astype('int16')}")
        logger.debug(f"Root Mean Square raw signal = {_rms:.3f}")

        # Save raw channel data as .dat binary
        logger.debug(f"cache_dir_path = '{cache_dir_path}'")
        cached_binary_raw_path = utils.fmt_binary_cache_wav_path(
                cache_dir_path,
                basename,
                channel,
                "RAW")
        (sig * SCALE_RAW).astype("int16").tofile(cached_binary_raw_path)

        print(f"Computing {NUM_FREQ} Gabor wavelet convolutions for channel {channel}.")
        # Loop through each frequency, convolve with Gabor wavelet
        # (total is given explicitly because enumerate has no length)
        for f_idx,freq in tqdm(enumerate(FREQS), total=len(FREQS)):
            # Define cache filepath for this channel & freq
            cached_binary_amp_path,cached_binary_phase_path = utils.fmt_binary_cache_wav_path(
                    cache_dir_path,
                    basename,
                    channel,
                    "AMP-PHASE",
                    f_idx)

            # Convolve signal with the wavelet, similar to awt_freqs
            wt = compute_wavelet_gabor(signal=sig,fs=FS,freqs=[freq])

            # TODO: do we need this condition? Won't we always use wavelet power?
            # if we don't use wavelet power it might break some of the code
            # Conditionally Zscore, re-scale, and save the power
            if "wavelet_power" in TS_FEATURES:
                wt_power = np.abs(wt)
                if ZSCORE_POWER:
                    wt_power = zscore(wt_power)
                # Comment in Options.toml: SCALE_POWER should be smaller if no zscore
                logger.debug(f"wt_power.dtype={wt_power.dtype}")
                wt_power = (wt_power * SCALE_POWER).astype("int16")
                wt_power.tofile(cached_binary_amp_path)

            # Conditionally re-scale and save the phase
            if "wavelet_phase" in TS_FEATURES:
                # NOTE(review): arctan(real/imag) is not the usual complex
                # argument (np.angle(wt) == arctan2(imag, real)); it only
                # spans (-pi/2, pi/2) and divides by zero when imag == 0.
                # Kept as is because it defines the data already on disk,
                # confirm against the MatLab pipeline before changing.
                wt_phase = (np.arctan(np.real(wt) / np.imag(wt)) * SCALE_PHASE).astype("int16")
                wt_phase.tofile(cached_binary_phase_path)
        logger.info(f"Finished computing wavelet transforms for channel={channel}.")

        # Check all the dat files in the cache match with the regex,
        # this is a test to make sure our cached folder is not corrupted
        # Filenames must not contain any dots!
        # Regexp must match all the Amplitude (=Power) and Phase transforms
        # as well as the raw .dat binary.
        # Raw string so that \d reaches the regexp engine untouched, the
        # base name is escaped so that it is matched literally.
        regex = rf"^{re.escape(basename)}_ch_{ch_tag}_(freqidx_\d\d|0RAW)(_A|_P|)$"
        # The extension is passed with its leading dot because the helper
        # compares it against the output of os.path.splitext
        _assert_all_ext_type_match_regexp(
                directory = cache_dir_path,
                extension = ".dat",
                regexp_base = regex)

        ### Merge all the cached frequency .dat files into a single one.
        # Sort the cached binaries, this will put lower f_idx first and
        # alternate A before P, exact same order they are created
        # All of the files in the cache have the same channel number (checked above)
        sorted_cache_binaries = sorted(
                i for i in os.listdir(cache_dir_path) if i.endswith(".dat"))
        # Sanity check assert
        assert "RAW" in sorted_cache_binaries[0], "Error in sorting the binaries, first file must be the raw channel"
        scb_paths = [os.path.join(cache_dir_path,i) for i in sorted_cache_binaries]
        merge_dats(
                fpaths_in = scb_paths,
                dir_out = WAVELET_BINARIES_PATH,
                fname_out = f"{basename}_ch_{ch_tag}.dat")

        # Delete the cached single-channel binary files
        shutil.rmtree(cache_dir_path)
    return
Computes and saves a wavelet decomposition of each channel.
Uses dictionaries loaded from user defined options from Options.toml (options_filepath) file to compute the Gabor wavelet decomposition of the raw signals in the provided edf file (edf_fname). This function doesn't return anything, but reads and writes to disk.
The signals are scaled before saving to hard disk, this is to mitigate quantization effects, since we are saving our data as int16.
- Reads edf raw signal specified by edf_fname (and fio_ops params)
- Iterates through each channel, computing wavelet convolutions for frequencies in a range specified by data_ops
- Saves output binaries, one binary file for each hardware channel, all the frequencies are saved according to the below order
Binaries array flattening convention:
- Read 'sn' as 'sample number n'
- A is for Amplitude (=Power), and P is for Phase
- K is the index of the last frequency (= num of freqs - 1) [raw_s0,freq00_A_s0,freq00_P_s0,freq01_A_s0,freq01_P_s0,...,freqK_A_s0, freqK_P_s0,raw_s1,freq00_A_s1,freq00_P_s1,...,freqK_A_s1,freqK_P_s1,... ... raw_sn,freq00_A_sn,freq00_P_sn,freq01_A_sn,...,freqK_P_sn]
Note: it is important that the above convention is respected, because this is how the binary_io tools read the files. It is the same convention as the MATLAB suite.
Parameters
edf_fname
The name of the '.edf' raw data file. We look for all edf files
in fio_ops.RAW_DATA_PATH from Options.toml
fio_ops : dict
The fio parameters defined in the Options.toml config file.
data_ops : dict
Data parameters from the Options.toml config file.
The Options.toml config file contains user defined parameters,
including fio params, data and data-processing params, ML
model hyper-params, and training params.
def make_wavelet_bank_all(options_path="Options.toml"):
    """Runs make_wavelet_bank on every '.edf' file of the raw data directory.

    Parameters
    ----------
    `options_path : str`
        Path of the config file the fio and data options are loaded from.
        The '.edf' files are listed from its fio option RAW_DATA_PATH.
    """
    # Load the user-defined options once, they are shared by all files
    fio_ops, data_ops = load_fio_ops_and_data_ops(options_path)
    raw_dir = fio_ops["RAW_DATA_PATH"]

    # Keep only the directory entries with the exact '.edf' extension
    edf_files = []
    for entry in os.listdir(raw_dir):
        _, suffix = os.path.splitext(entry)
        if suffix == ".edf":
            edf_files.append(entry)

    print(f"Make wavelet bank all, convolving {len(edf_files)} edf files.")
    for edf_fname in edf_files:
        logger.info(f"Running make_wavelet_bank on {edf_fname}")
        # Convolve with wavelets and write binary files
        make_wavelet_bank(edf_fname, fio_ops, data_ops)

    return