ihkapy.sm_make_wavelet_bank

Author statement

This file is based on sm_getPowerPerChannel.m, written by Samuel McKenzie, and on awt_freqlist.m, written by Maureen Clerc and Christian Benar in October 2007. Translated and adapted to Python in May 2022 by Stephen Fay (dcxstephen@gmail.com).

  1"""
  2Author statement
  3----------------
  4This file is based off of sm_getPowerPerChannel.m written by Samuel McKenzie, 
  5and awt_freqlist.m and Maureen Clerc, Christian Benar, october 2007 
  6Translated and adapted to python in May 2022 by Stephen Fay dcxstephen@gmail.com
  7"""
  8
  9# TODO: standardize name 'channel' to 'raw_ch_idx'
 10# TODO: standardize naming convention 'amp' and 'power' are used as synonyms
 11#       I think 'amp' is better because it looks less like 'phase', easy distinguish
 12
 13from ihkapy.fileio.binary_io import merge_dats # local dependency
 14from ihkapy.fileio import utils
 15from ihkapy.fileio.options_io import load_fio_ops_and_data_ops
 16import os                       # I/O
 17import shutil                   # I/O
 18from tqdm import tqdm           # Progressbar
 19import logging                  # For debugging and following code
 20import warnings                 # Bulletproof code
 21import re                       # Regexp library, to bulletproof code
 22import pyedflib                 # Read from edf files | this is "terrible in MatLab", look into it, apparently it loads everything into RAM
 23import numpy as np              # Array manipulation, Scientific computing
 24from numpy.fft import fft, ifft # Signal processing
 25from scipy.stats import zscore  # Signal processing
 26
 27
 28# Init logger and set the logging level
 29logging.basicConfig(level=logging.DEBUG)
 30logger = logging.getLogger(__name__)
 31
 32# For variables containing strings with an absolute path, we explicitly 
 33# include the word "path" in the variable name. For those with relative 
 34# or leaf paths, we do not put "path" in the name. 
 35
 36# TODO: implement for Lusin and Sombrero wavelets too
 37# Note, Lusin wasn't implemented in Matlab, and pipeline only used Gabor
 38# 
 39# compute_wavelet_gabor corresponds to awt_freqlist in the MatLab code
 40# 
 41# %  Maureen Clerc, Christian Benar, october 2007
 42# %  modified from awt from wavelab toolbox
 43# 
 44# % History of changes
 45# % 1/11/2007: (cgb) psi_array: output in complex 
 46# % 3/06/2008: (cgb) init of psi_array to size of wt
 47#   3/05/2022: SF translated awt_freqlist.m to Python
 48
 49def compute_wavelet_gabor(
 50        signal: np.ndarray,
 51        fs: int or float,
 52        freqs: list or float,
 53        xi: int = 5 # only needed for Gabor
 54        ) -> np.ndarray: 
 55    """Computes one or multiple wavelet transforms of the input signal.
 56
 57    Follows awt_freqlist.m from the buzzcode repository.
 58
 59    Parameters
 60    ----------
 61    `signal : np.ndarray`
 62        The input signal. Only accepts 1D signals. 
 63
 64    `fs : int or float`
 65        The sampling frequency. 
 66
 67    `freqs : list or float`
 68        The frequency or list of frequencies to compute. 
 69
 70    `xi : int`
 71        The number of oscillations parameter, only needed for Gabor wavelet.
 72
 73    Returns
 74    -------
 75    `np.ndarray`
 76        A numpy array of dim (len(freqs),len(signal))
 77    """
 78    # Make sure all types are correct
 79    if isinstance(freqs, float) or isinstance(freqs, int): freqs = [freqs]
 80    freqs = np.asarray(freqs)
 81    signal = np.asarray(signal)
 82    assert fs > 0 and (isinstance(fs, float) or isinstance(fs, int))
 83    assert signal.ndim == 1, "Must be single dim signal" 
 84    # TODO: implement multi-dim and remove above assertion
 85    # (not crucial because we don't (yet) use that in pipeline)
 86
 87    (len_sig,) = signal.shape
 88    sigma2 = 1
 89    omega = np.concatenate((np.arange(0,len_sig//2+1) , np.arange(-((len_sig+1)//2)+1,0))) * fs / len_sig
 90    # omega *= fs / len_sig
 91
 92    # Warning: this code was dogmatically translated from MatLab repo 
 93    tolerance = 0.5
 94    mincenterfreq = 2*tolerance*np.sqrt(sigma2)*fs*xi / len_sig
 95    maxcenterfreq = fs*xi/(xi+tolerance/np.sqrt(sigma2)) # Shouldn't this be divided by two because of aliasing? 
 96    nyquist = fs / 2
 97    maxcenterfreq = min(maxcenterfreq,nyquist)
 98    logger.debug(f"fs = {fs}")
 99    logger.debug(f"freqs = {freqs}")
100    logger.debug(f"\n\tLowest freq = {min(freqs)}\n\tHighest freq = {max(freqs)}")
101    logger.debug(f"\n\tmincenterfreq = {mincenterfreq}\n\tmaxcenterfreq = {maxcenterfreq}")
102
103    s_arr = xi / freqs
104    minscale = xi / maxcenterfreq
105    maxscale = xi / mincenterfreq
106    # reject frequencies that are outside the given scale
107    if ((s_arr >= minscale) | (s_arr <= maxscale)).any():
108        warnings.warn("Frequencies are not between minscale and maxscale.")
109
110    n_freqs = len(freqs)
111    # np.complex64 is numpy's coarsest complex numpy type
112    wt = np.zeros((len_sig,n_freqs),dtype=np.complex64) 
113    
114    for idx,s in enumerate(s_arr):
115        freq = (s * omega - xi)
116        psi = np.power(4*np.pi*sigma2,0.25) * np.sqrt(s) * np.exp(-sigma2/2 * freq*freq)
117        wt[:,idx] = ifft(fft(signal) * psi)
118
119    return np.squeeze(wt) # turns 2d into 1d IFF single freq 
120
121
# Helper, test to make sure our cache folder is not corrupted
def _assert_all_ext_type_match_regexp(
        directory: str,
        extension: str,
        regexp_base: str):
    """Assert that every file in `directory` with `extension` matches a regexp.

    Parameters
    ----------
    `directory : str`
        Directory whose (non-recursive) listing is checked.

    `extension : str`
        File extension to check, with or without the leading dot
        (``"dat"`` and ``".dat"`` are equivalent).

    `regexp_base : str`
        Regular expression searched (``re.search``) in each checked file's
        base name, i.e. the name without its extension.

    Raises
    ------
    `AssertionError`
        If a file with the given extension does not match `regexp_base`.
    """
    # os.path.splitext returns the extension WITH its dot; normalise so a
    # caller passing "dat" does not silently skip every file.
    if extension and not extension.startswith("."):
        extension = "." + extension
    pattern = re.compile(regexp_base)
    for fname in os.listdir(directory):
        base, ext = os.path.splitext(fname)
        if ext == extension:
            assert pattern.search(base), (
                f"File '{fname}' in {directory} does not match the regexp:\n{regexp_base}")
    logger.debug(f"Test passed: all '{extension}' files in {directory} match the regexp:\n{regexp_base}")
    return
132
133
def _save_int16(arr, path):
    """Cast `arr` to int16 with saturation and write it as a raw binary to `path`.

    A plain ``astype("int16")`` does not saturate: out-of-range floats wrap
    around or become garbage. Values whose truncation would not fit in int16
    are clipped to the int16 limits first, and a warning is logged so the
    scale factor can be adjusted. In-range values are truncated toward zero,
    exactly like ``astype("int16")``.
    """
    info = np.iinfo(np.int16)
    arr = np.asarray(arr)
    n_clipped = np.count_nonzero((arr >= info.max + 1) | (arr <= info.min - 1))
    if n_clipped:
        logger.warning(
            "%d values exceed the int16 range and were clipped while writing %s; "
            "consider lowering the scale factor.", n_clipped, path)
    np.clip(arr, info.min, info.max).astype(np.int16).tofile(path)


def make_wavelet_bank(
        edf_fname:str,
        fio_ops:dict,
        data_ops:dict):
    """Computes and saves a wavelet decomposition of each channel.

    Uses dictionaries loaded from the user-defined options in Options.toml
    to compute the Gabor wavelet decomposition of the raw signals in the
    provided edf file (edf_fname).
    This function doesn't return anything, but reads and writes to disk.

    The signals are scaled before saving to hard disk, to mitigate
    quantization effects, since the data are saved as int16. Values that
    still fall outside the int16 range after scaling are clipped (with a
    logged warning) instead of silently wrapping around.

    - Reads the edf raw signal specified by edf_fname (and fio_ops params)
    - Iterates through each channel, computing wavelet convolutions for
        NUM_FREQ log-spaced frequencies between LOW_FREQ and HIGH_FREQ
        (from data_ops)
    - Saves output binaries, one binary file per hardware channel
        (``{basename}_ch_{ccc}.dat`` in WAVELET_BINARIES_PATH); all the
        frequencies are saved according to the order below

    Binaries array flattening convention:
    - Read 'sn' as 'sample number n'
    - A is for Amplitude (=Power), and P is for Phase (radians in
      [-pi, pi], multiplied by SCALE_PHASE)
    - K is the index of the last frequency (= num of freqs - 1)
    [raw_s0,freq00_A_s0,freq00_P_s0,freq01_A_s0,freq01_P_s0,...,freqK_A_s0,
    freqK_P_s0,raw_s1,freq00_A_s1,freq00_P_s1,...,freqK_A_s1,freqK_P_s1,...
    ...
    raw_sn,freq00_A_sn,freq00_P_sn,freq01_A_sn,...,freqK_P_sn]
    A (resp. P) streams are only written when "wavelet_power" (resp.
    "wavelet_phase") is listed in data_ops["TS_FEATURES"].

    Note: it is important the above convention is respected because this is
    how the binary_io tools read the files. It's the same convention as the
    MATLAB suite.

    Parameters
    ----------

    `edf_fname : str`
        The name of the '.edf' raw data file (no dots allowed in the base
        name). The file is looked up in fio_ops["RAW_DATA_PATH"].

    `fio_ops : dict`
        The fio parameters defined in the Options.toml config file.

    `data_ops : dict`
        Data parameters from the Options.toml config file.
        The Options.toml config file contains user defined parameters,
        including fio params, data and data-processing params, ML
        model hyper-params, and training params.

    Raises
    ------
    `AssertionError`
        On a malformed file name, missing paths, a mismatch between
        data_ops["N_CHAN_RAW"] and the channels in the edf file, or
        unexpected cache contents.
    """
    assert len(edf_fname.split("."))==2, f"There can be no dots in the base file name, {edf_fname}"
    basename,ext = edf_fname.split(".")
    assert ext == "edf", f"Incorrect file format, expected .edf, got {edf_fname}"

    # Unpack File IO constants
    RAW_DATA_PATH         = fio_ops["RAW_DATA_PATH"]
    WAVELET_BINARIES_PATH = fio_ops["WAVELET_BINARIES_PATH"]
    # Unpack Data config constants
    FS           = data_ops["FS"]
    NUM_FREQ     = data_ops["NUM_FREQ"]
    LOW_FREQ     = data_ops["LOW_FREQ"]
    HIGH_FREQ    = data_ops["HIGH_FREQ"]
    ZSCORE_POWER = data_ops["ZSCORE_POWER"] # bool
    SCALE_RAW    = data_ops["SCALE_RAW"]
    SCALE_PHASE  = data_ops["SCALE_PHASE"]
    SCALE_POWER  = data_ops["SCALE_POWER"]
    N_CHAN_RAW   = data_ops["N_CHAN_RAW"]
    TS_FEATURES  = data_ops["TS_FEATURES"]
    # Define features space (based on num, lo, high-freq).
    # NOTE: data_ops["SPACING"] is not consulted; frequencies are always
    # log-spaced.
    FREQS = np.logspace(np.log10(LOW_FREQ), np.log10(HIGH_FREQ), NUM_FREQ)

    ### IO Checks
    # Check edf file and wavelet binaries directory exist
    edf_path = os.path.join(RAW_DATA_PATH, edf_fname)
    assert os.path.exists(edf_path), f"Invalid edf file path. Make sure edf file exists and is inside of the {RAW_DATA_PATH} directory."
    assert os.path.exists(WAVELET_BINARIES_PATH), f"Invalid path {WAVELET_BINARIES_PATH}\nCheck your configuration file Options.toml"
    # Check if the cache folder for binaries exists, delete it
    cache_dir_path = os.path.join(WAVELET_BINARIES_PATH, "cache")
    if os.path.exists(cache_dir_path):
        logger.info("Deleting cache")
        shutil.rmtree(cache_dir_path)
    else: logger.debug("No cache, proceeding")

    # Process one channel at a time: cache per-stream binaries, then merge them
    for channel in range(N_CHAN_RAW):
        ch_tag = str(channel).zfill(3)
        os.mkdir(cache_dir_path) # Create the cache directory

        with pyedflib.EdfReader(edf_path) as f:
            assert N_CHAN_RAW == f.signals_in_file, f"N_CHAN_RAW={N_CHAN_RAW} incompatible with detected number of channels in file ={f.signals_in_file}"
            # Read as float: casting to int16 here would quantize the signal badly
            sig = f.readSignal(channel)
            # Query while the reader is still open, for THIS channel
            n_samples = f.getNSamples()[channel]
        assert sig.shape == (n_samples,), f"Channel {channel}: expected {n_samples} samples, got shape {sig.shape}"
        if logger.isEnabledFor(logging.DEBUG): # skip the RMS pass otherwise
            logger.debug("Sample of quantized sig %s", sig.astype('int16')[:10])
            logger.debug("Root Mean Square raw signal = %.3f", np.sqrt(np.mean(sig ** 2)))

        # Save raw channel data as .dat binary
        logger.debug("cache_dir_path = '%s'", cache_dir_path)
        cached_binary_raw_path = utils.fmt_binary_cache_wav_path(
                cache_dir_path,
                basename,
                channel,
                "RAW")
        _save_int16(sig * SCALE_RAW, cached_binary_raw_path)
        # Cached files in the exact order merge_dats must interleave them:
        # raw first, then (A, P) for each frequency index
        cached_paths = [cached_binary_raw_path]

        print(f"Computing {NUM_FREQ} Gabor wavelet convolutions for channel {channel}.")
        # Loop through each frequency, convolve with Gabor wavelet
        for f_idx, freq in enumerate(tqdm(FREQS)):
            # Define cache filepath for this channel & freq
            cached_binary_amp_path, cached_binary_phase_path = utils.fmt_binary_cache_wav_path(
                    cache_dir_path,
                    basename,
                    channel,
                    "AMP-PHASE",
                    f_idx)

            # Convolve signal with the wavelet, similar to awt_freqs
            wt = compute_wavelet_gabor(signal=sig, fs=FS, freqs=[freq])

            # TODO: do we need this condition? Won't we always use wavelet power?
            #       if we don't use wavelet power it might break some of the code
            # Conditionally Zscore, re-scale, and save the power
            if "wavelet_power" in TS_FEATURES:
                wt_power = np.abs(wt)
                if ZSCORE_POWER:
                    wt_power = zscore(wt_power)
                # Comment in Options.toml: SCALE_POWER should be smaller if no zscore
                logger.debug("wt_power.dtype=%s", wt_power.dtype)
                _save_int16(wt_power * SCALE_POWER, cached_binary_amp_path)
                cached_paths.append(cached_binary_amp_path)

            # Conditionally re-scale and save the phase.
            # np.angle is the true phase atan2(imag, real) in [-pi, pi]. The
            # former arctan(real / imag) equals pi/2 - phase folded into
            # (-pi/2, pi/2), and gave NaN where real == imag == 0.
            if "wavelet_phase" in TS_FEATURES:
                _save_int16(np.angle(wt) * SCALE_PHASE, cached_binary_phase_path)
                cached_paths.append(cached_binary_phase_path)
        logger.info(f"Finished computing wavelet transforms for channel={channel}.")

        # Check all the dat files in the cache match with the regex;
        # this is a test to make sure our cached folder is not corrupted.
        # Filenames must not contain any dots!
        # Regexp must match all the Amplitude (=Power) and Phase transforms
        # as well as the raw .dat binary.
        regex = rf"^{re.escape(basename)}_ch_{ch_tag}_(freqidx_\d\d|0RAW)(_A|_P|)$"
        _assert_all_ext_type_match_regexp(
                directory = cache_dir_path,
                extension = ".dat",
                regexp_base = regex)
        # The cache must hold exactly the files written above
        listed = sorted(i for i in os.listdir(cache_dir_path) if i.endswith(".dat"))
        expected = sorted(os.path.basename(p) for p in cached_paths)
        assert listed == expected, f"Unexpected cache contents in {cache_dir_path}: {listed}"

        ### Merge all the cached .dat files of this channel into a single one.
        # The explicit creation-order list is used instead of sorting
        # directory names, so the order does not depend on how
        # utils.fmt_binary_cache_wav_path pads f_idx.
        merge_dats(
                fpaths_in = cached_paths,
                dir_out = WAVELET_BINARIES_PATH,
                fname_out = f"{basename}_ch_{ch_tag}.dat")

        # Delete the cached single-channel binary files
        shutil.rmtree(cache_dir_path)
    return
304
305
def make_wavelet_bank_all(options_path="Options.toml"):
    """Runs make_wavelet_bank on every '.edf' file in fio_ops["RAW_DATA_PATH"].

    Files are processed in sorted name order so runs are reproducible.

    Parameters
    ----------
    `options_path : str`
        Path to the Options.toml config file, loaded with
        load_fio_ops_and_data_ops.
    """
    # Unpack parameters and user-defined constants
    fio_ops, data_ops = load_fio_ops_and_data_ops(options_path)
    # Bind the consts we use to local vars for readability
    RAW_DATA_PATH = fio_ops["RAW_DATA_PATH"]

    # os.listdir order is arbitrary; sort it. Directories named *.edf are skipped.
    edf_files = sorted(
        fname for fname in os.listdir(RAW_DATA_PATH)
        if os.path.splitext(fname)[1] == ".edf"
        and os.path.isfile(os.path.join(RAW_DATA_PATH, fname)))
    print(f"Make wavelet bank all, convolving {len(edf_files)} edf files.")
    for edf_fname in edf_files:
        logger.info("Running make_wavelet_bank on %s", edf_fname)
        # Convolve with wavelets and write binary files
        make_wavelet_bank(edf_fname, fio_ops, data_ops)

    return
320
321
# Next block runs only when you run this file directly, not on import.
# Usage: python -m ihkapy.sm_make_wavelet_bank [OPTIONS_PATH]
# OPTIONS_PATH defaults to Options_test.toml (pass Options.toml for real runs).
if __name__ == "__main__":
    import sys
    options_path = sys.argv[1] if len(sys.argv) > 1 else "Options_test.toml"
    make_wavelet_bank_all(options_path=options_path)
326    
327
328    
def compute_wavelet_gabor( signal: numpy.ndarray, fs: int, freqs: list, xi: int = 5) -> numpy.ndarray:
 50def compute_wavelet_gabor(
 51        signal: np.ndarray,
 52        fs: int or float,
 53        freqs: list or float,
 54        xi: int = 5 # only needed for Gabor
 55        ) -> np.ndarray: 
 56    """Computes one or multiple wavelet transforms of the input signal.
 57
 58    Follows awt_freqlist.m from the buzzcode repository.
 59
 60    Parameters
 61    ----------
 62    `signal : np.ndarray`
 63        The input signal. Only accepts 1D signals. 
 64
 65    `fs : int or float`
 66        The sampling frequency. 
 67
 68    `freqs : list or float`
 69        The frequency or list of frequencies to compute. 
 70
 71    `xi : int`
 72        The number of oscillations parameter, only needed for Gabor wavelet.
 73
 74    Returns
 75    -------
 76    `np.ndarray`
 77        A numpy array of dim (len(freqs),len(signal))
 78    """
 79    # Make sure all types are correct
 80    if isinstance(freqs, float) or isinstance(freqs, int): freqs = [freqs]
 81    freqs = np.asarray(freqs)
 82    signal = np.asarray(signal)
 83    assert fs > 0 and (isinstance(fs, float) or isinstance(fs, int))
 84    assert signal.ndim == 1, "Must be single dim signal" 
 85    # TODO: implement multi-dim and remove above assertion
 86    # (not crucial because we don't (yet) use that in pipeline)
 87
 88    (len_sig,) = signal.shape
 89    sigma2 = 1
 90    omega = np.concatenate((np.arange(0,len_sig//2+1) , np.arange(-((len_sig+1)//2)+1,0))) * fs / len_sig
 91    # omega *= fs / len_sig
 92
 93    # Warning: this code was dogmatically translated from MatLab repo 
 94    tolerance = 0.5
 95    mincenterfreq = 2*tolerance*np.sqrt(sigma2)*fs*xi / len_sig
 96    maxcenterfreq = fs*xi/(xi+tolerance/np.sqrt(sigma2)) # Shouldn't this be divided by two because of aliasing? 
 97    nyquist = fs / 2
 98    maxcenterfreq = min(maxcenterfreq,nyquist)
 99    logger.debug(f"fs = {fs}")
100    logger.debug(f"freqs = {freqs}")
101    logger.debug(f"\n\tLowest freq = {min(freqs)}\n\tHighest freq = {max(freqs)}")
102    logger.debug(f"\n\tmincenterfreq = {mincenterfreq}\n\tmaxcenterfreq = {maxcenterfreq}")
103
104    s_arr = xi / freqs
105    minscale = xi / maxcenterfreq
106    maxscale = xi / mincenterfreq
107    # reject frequencies that are outside the given scale
108    if ((s_arr >= minscale) | (s_arr <= maxscale)).any():
109        warnings.warn("Frequencies are not between minscale and maxscale.")
110
111    n_freqs = len(freqs)
112    # np.complex64 is numpy's coarsest complex numpy type
113    wt = np.zeros((len_sig,n_freqs),dtype=np.complex64) 
114    
115    for idx,s in enumerate(s_arr):
116        freq = (s * omega - xi)
117        psi = np.power(4*np.pi*sigma2,0.25) * np.sqrt(s) * np.exp(-sigma2/2 * freq*freq)
118        wt[:,idx] = ifft(fft(signal) * psi)
119
120    return np.squeeze(wt) # turns 2d into 1d IFF single freq 

Computes one or multiple wavelet transforms of the input signal.

Follows awt_freqlist.m from the buzzcode repository.

Parameters

signal : np.ndarray The input signal. Only accepts 1D signals.

fs : int or float The sampling frequency.

freqs : list or float The frequency or list of frequencies to compute.

xi : int The number of oscillations parameter, only needed for Gabor wavelet.

Returns

np.ndarray A complex numpy array of dim (len(signal), len(freqs)), squeezed to a 1D array of dim (len(signal),) when a single frequency is given.

def make_wavelet_bank(edf_fname: str, fio_ops: dict, data_ops: dict)
def _save_int16(arr, path):
    """Cast `arr` to int16 with saturation and write it as a raw binary to `path`.

    A plain ``astype("int16")`` does not saturate: out-of-range floats wrap
    around or become garbage. Values whose truncation would not fit in int16
    are clipped to the int16 limits first, and a warning is logged so the
    scale factor can be adjusted. In-range values are truncated toward zero,
    exactly like ``astype("int16")``.
    """
    info = np.iinfo(np.int16)
    arr = np.asarray(arr)
    n_clipped = np.count_nonzero((arr >= info.max + 1) | (arr <= info.min - 1))
    if n_clipped:
        logger.warning(
            "%d values exceed the int16 range and were clipped while writing %s; "
            "consider lowering the scale factor.", n_clipped, path)
    np.clip(arr, info.min, info.max).astype(np.int16).tofile(path)


def make_wavelet_bank(
        edf_fname:str,
        fio_ops:dict,
        data_ops:dict):
    """Computes and saves a wavelet decomposition of each channel.

    Uses dictionaries loaded from the user-defined options in Options.toml
    to compute the Gabor wavelet decomposition of the raw signals in the
    provided edf file (edf_fname).
    This function doesn't return anything, but reads and writes to disk.

    The signals are scaled before saving to hard disk, to mitigate
    quantization effects, since the data are saved as int16. Values that
    still fall outside the int16 range after scaling are clipped (with a
    logged warning) instead of silently wrapping around.

    - Reads the edf raw signal specified by edf_fname (and fio_ops params)
    - Iterates through each channel, computing wavelet convolutions for
        NUM_FREQ log-spaced frequencies between LOW_FREQ and HIGH_FREQ
        (from data_ops)
    - Saves output binaries, one binary file per hardware channel
        (``{basename}_ch_{ccc}.dat`` in WAVELET_BINARIES_PATH); all the
        frequencies are saved according to the order below

    Binaries array flattening convention:
    - Read 'sn' as 'sample number n'
    - A is for Amplitude (=Power), and P is for Phase (radians in
      [-pi, pi], multiplied by SCALE_PHASE)
    - K is the index of the last frequency (= num of freqs - 1)
    [raw_s0,freq00_A_s0,freq00_P_s0,freq01_A_s0,freq01_P_s0,...,freqK_A_s0,
    freqK_P_s0,raw_s1,freq00_A_s1,freq00_P_s1,...,freqK_A_s1,freqK_P_s1,...
    ...
    raw_sn,freq00_A_sn,freq00_P_sn,freq01_A_sn,...,freqK_P_sn]
    A (resp. P) streams are only written when "wavelet_power" (resp.
    "wavelet_phase") is listed in data_ops["TS_FEATURES"].

    Note: it is important the above convention is respected because this is
    how the binary_io tools read the files. It's the same convention as the
    MATLAB suite.

    Parameters
    ----------

    `edf_fname : str`
        The name of the '.edf' raw data file (no dots allowed in the base
        name). The file is looked up in fio_ops["RAW_DATA_PATH"].

    `fio_ops : dict`
        The fio parameters defined in the Options.toml config file.

    `data_ops : dict`
        Data parameters from the Options.toml config file.
        The Options.toml config file contains user defined parameters,
        including fio params, data and data-processing params, ML
        model hyper-params, and training params.

    Raises
    ------
    `AssertionError`
        On a malformed file name, missing paths, a mismatch between
        data_ops["N_CHAN_RAW"] and the channels in the edf file, or
        unexpected cache contents.
    """
    assert len(edf_fname.split("."))==2, f"There can be no dots in the base file name, {edf_fname}"
    basename,ext = edf_fname.split(".")
    assert ext == "edf", f"Incorrect file format, expected .edf, got {edf_fname}"

    # Unpack File IO constants
    RAW_DATA_PATH         = fio_ops["RAW_DATA_PATH"]
    WAVELET_BINARIES_PATH = fio_ops["WAVELET_BINARIES_PATH"]
    # Unpack Data config constants
    FS           = data_ops["FS"]
    NUM_FREQ     = data_ops["NUM_FREQ"]
    LOW_FREQ     = data_ops["LOW_FREQ"]
    HIGH_FREQ    = data_ops["HIGH_FREQ"]
    ZSCORE_POWER = data_ops["ZSCORE_POWER"] # bool
    SCALE_RAW    = data_ops["SCALE_RAW"]
    SCALE_PHASE  = data_ops["SCALE_PHASE"]
    SCALE_POWER  = data_ops["SCALE_POWER"]
    N_CHAN_RAW   = data_ops["N_CHAN_RAW"]
    TS_FEATURES  = data_ops["TS_FEATURES"]
    # Define features space (based on num, lo, high-freq).
    # NOTE: data_ops["SPACING"] is not consulted; frequencies are always
    # log-spaced.
    FREQS = np.logspace(np.log10(LOW_FREQ), np.log10(HIGH_FREQ), NUM_FREQ)

    ### IO Checks
    # Check edf file and wavelet binaries directory exist
    edf_path = os.path.join(RAW_DATA_PATH, edf_fname)
    assert os.path.exists(edf_path), f"Invalid edf file path. Make sure edf file exists and is inside of the {RAW_DATA_PATH} directory."
    assert os.path.exists(WAVELET_BINARIES_PATH), f"Invalid path {WAVELET_BINARIES_PATH}\nCheck your configuration file Options.toml"
    # Check if the cache folder for binaries exists, delete it
    cache_dir_path = os.path.join(WAVELET_BINARIES_PATH, "cache")
    if os.path.exists(cache_dir_path):
        logger.info("Deleting cache")
        shutil.rmtree(cache_dir_path)
    else: logger.debug("No cache, proceeding")

    # Process one channel at a time: cache per-stream binaries, then merge them
    for channel in range(N_CHAN_RAW):
        ch_tag = str(channel).zfill(3)
        os.mkdir(cache_dir_path) # Create the cache directory

        with pyedflib.EdfReader(edf_path) as f:
            assert N_CHAN_RAW == f.signals_in_file, f"N_CHAN_RAW={N_CHAN_RAW} incompatible with detected number of channels in file ={f.signals_in_file}"
            # Read as float: casting to int16 here would quantize the signal badly
            sig = f.readSignal(channel)
            # Query while the reader is still open, for THIS channel
            n_samples = f.getNSamples()[channel]
        assert sig.shape == (n_samples,), f"Channel {channel}: expected {n_samples} samples, got shape {sig.shape}"
        if logger.isEnabledFor(logging.DEBUG): # skip the RMS pass otherwise
            logger.debug("Sample of quantized sig %s", sig.astype('int16')[:10])
            logger.debug("Root Mean Square raw signal = %.3f", np.sqrt(np.mean(sig ** 2)))

        # Save raw channel data as .dat binary
        logger.debug("cache_dir_path = '%s'", cache_dir_path)
        cached_binary_raw_path = utils.fmt_binary_cache_wav_path(
                cache_dir_path,
                basename,
                channel,
                "RAW")
        _save_int16(sig * SCALE_RAW, cached_binary_raw_path)
        # Cached files in the exact order merge_dats must interleave them:
        # raw first, then (A, P) for each frequency index
        cached_paths = [cached_binary_raw_path]

        print(f"Computing {NUM_FREQ} Gabor wavelet convolutions for channel {channel}.")
        # Loop through each frequency, convolve with Gabor wavelet
        for f_idx, freq in enumerate(tqdm(FREQS)):
            # Define cache filepath for this channel & freq
            cached_binary_amp_path, cached_binary_phase_path = utils.fmt_binary_cache_wav_path(
                    cache_dir_path,
                    basename,
                    channel,
                    "AMP-PHASE",
                    f_idx)

            # Convolve signal with the wavelet, similar to awt_freqs
            wt = compute_wavelet_gabor(signal=sig, fs=FS, freqs=[freq])

            # TODO: do we need this condition? Won't we always use wavelet power?
            #       if we don't use wavelet power it might break some of the code
            # Conditionally Zscore, re-scale, and save the power
            if "wavelet_power" in TS_FEATURES:
                wt_power = np.abs(wt)
                if ZSCORE_POWER:
                    wt_power = zscore(wt_power)
                # Comment in Options.toml: SCALE_POWER should be smaller if no zscore
                logger.debug("wt_power.dtype=%s", wt_power.dtype)
                _save_int16(wt_power * SCALE_POWER, cached_binary_amp_path)
                cached_paths.append(cached_binary_amp_path)

            # Conditionally re-scale and save the phase.
            # np.angle is the true phase atan2(imag, real) in [-pi, pi]. The
            # former arctan(real / imag) equals pi/2 - phase folded into
            # (-pi/2, pi/2), and gave NaN where real == imag == 0.
            if "wavelet_phase" in TS_FEATURES:
                _save_int16(np.angle(wt) * SCALE_PHASE, cached_binary_phase_path)
                cached_paths.append(cached_binary_phase_path)
        logger.info(f"Finished computing wavelet transforms for channel={channel}.")

        # Check all the dat files in the cache match with the regex;
        # this is a test to make sure our cached folder is not corrupted.
        # Filenames must not contain any dots!
        # Regexp must match all the Amplitude (=Power) and Phase transforms
        # as well as the raw .dat binary.
        regex = rf"^{re.escape(basename)}_ch_{ch_tag}_(freqidx_\d\d|0RAW)(_A|_P|)$"
        _assert_all_ext_type_match_regexp(
                directory = cache_dir_path,
                extension = ".dat",
                regexp_base = regex)
        # The cache must hold exactly the files written above
        listed = sorted(i for i in os.listdir(cache_dir_path) if i.endswith(".dat"))
        expected = sorted(os.path.basename(p) for p in cached_paths)
        assert listed == expected, f"Unexpected cache contents in {cache_dir_path}: {listed}"

        ### Merge all the cached .dat files of this channel into a single one.
        # The explicit creation-order list is used instead of sorting
        # directory names, so the order does not depend on how
        # utils.fmt_binary_cache_wav_path pads f_idx.
        merge_dats(
                fpaths_in = cached_paths,
                dir_out = WAVELET_BINARIES_PATH,
                fname_out = f"{basename}_ch_{ch_tag}.dat")

        # Delete the cached single-channel binary files
        shutil.rmtree(cache_dir_path)
    return

Computes and saves a wavelet decomposition of each channel.

Uses dictionaries loaded from user defined options from Options.toml (options_filepath) file to compute the Gabor wavelet decomposition of the raw signals in the provided edf file (edf_fname). This function doesn't return anything, but reads and writes to disk.

The signals are scaled before saving to hard disk, this is to mitigate quantization effects, since we are saving our data as int16.

  • Reads edf raw signal specified by edf_fname (and fio_ops params)
  • Iterates through each channel, computing wavelet convolutions for frequencies in a range specified by data_ops
  • Saves output binaries, one binary file for each hardware channel, all the frequencies are saved according to the below order

Binaries array flattening convention:

  • Read 'sn' as 'sample number n'
  • A is for Amplitude (=Power), and P is for Phase
  • K is the index of the last frequency (= num of freqs - 1) [raw_s0,freq00_A_s0,freq00_P_s0,freq01_A_s0,freq01_P_s0,...,freqK_A_s0, freqK_P_s0,raw_s1,freq00_A_s1,freq00_P_s1,...,freqK_A_s1,freqK_P_s1,... ... raw_sn,freq00_A_sn,freq00_P_sn,freq01_A_sn,...,freqK_P_sn]

Note: it is important that the above convention is respected, because this is how the binary_io tools read the files. It's the same convention as the MATLAB suite.

Parameters

edf_fname The name of the '.edf' raw data file (the base name must not contain dots). The file is looked up in fio_ops["RAW_DATA_PATH"] from Options.toml.

fio_ops : dict The fio parameters defined in the Options.toml config file.

data_ops : dict Data parameters from the Options.toml config file. The Options.toml config file contains user defined parameters, including fio params, data and data-processing params, ML model hyper-params, and training params.

def make_wavelet_bank_all(options_path='Options.toml')
def make_wavelet_bank_all(options_path="Options.toml"):
    """Runs make_wavelet_bank on every '.edf' file in fio_ops["RAW_DATA_PATH"].

    Files are processed in sorted name order so runs are reproducible.

    Parameters
    ----------
    `options_path : str`
        Path to the Options.toml config file, loaded with
        load_fio_ops_and_data_ops.
    """
    # Unpack parameters and user-defined constants
    fio_ops, data_ops = load_fio_ops_and_data_ops(options_path)
    # Bind the consts we use to local vars for readability
    RAW_DATA_PATH = fio_ops["RAW_DATA_PATH"]

    # os.listdir order is arbitrary; sort it. Directories named *.edf are skipped.
    edf_files = sorted(
        fname for fname in os.listdir(RAW_DATA_PATH)
        if os.path.splitext(fname)[1] == ".edf"
        and os.path.isfile(os.path.join(RAW_DATA_PATH, fname)))
    print(f"Make wavelet bank all, convolving {len(edf_files)} edf files.")
    for edf_fname in edf_files:
        logger.info("Running make_wavelet_bank on %s", edf_fname)
        # Convolve with wavelets and write binary files
        make_wavelet_bank(edf_fname, fio_ops, data_ops)

    return