ihkapy.sm_calc_features

import numpy as np
import os
from ihkapy.fileio import utils
from ihkapy.fileio.utils import get_all_valid_session_basenames,check_session_basenames_are_valid
from ihkapy.fileio.binary_io import get_n_samples_from_dur_fs,load_binary_multiple_segments
from ihkapy.fileio.metadata_io import get_seizure_start_end_times
from ihkapy.fileio.options_io import load_ops_as_dict # For loading Options.toml
from scipy.signal import coherence
from ihkapy import sm_features # for featurizing
import warnings
import pandas as pd
from tqdm import tqdm


# pure function (not pure in the strict sense, but it doesn't read or write files)
def _get_x_pct_time_of_interval(
        start_time : float,      # in seconds
        end_time : float,        # in seconds
        segment_length : float,  # in seconds
        pct : float              # proportion of segments to sample
        ) -> np.ndarray:
    """Get a randomly sampled 1d array of start times

    Parameters
    ----------
    `start_time : float`
        The beginning timestamp, in seconds, of the interval we sample from.

    `end_time : float`
        The end timestamp, in seconds, of the interval we sample from.

    `segment_length : float`
        The length, in seconds, of each sampled segment.

    `pct : float`
        The proportion of segments to select, must be between 0 and 1.

    Returns
    -------
    `np.ndarray`
        A 1d numpy array of start times for the segments (in seconds).
        Together with the segment length, these fully define the segments
        of interest that we would like to sample from.
    """
    assert end_time > start_time
    assert pct >= 0.0 and pct <= 1.0
    # The number of non-overlapping segments that fit in the interval
    n_segments = int((end_time - start_time) // segment_length)
    # Start times of the maximum number of non-overlapping segments
    # that fit in the interval.
    segment_start_times = np.linspace(start_time,end_time-segment_length,n_segments)
    if pct == 1.0: return segment_start_times
    # Choose and sort a random sample according to pct
    n_select = int(np.ceil(n_segments * pct)) # number of segments to select
    segment_start_times = np.random.choice(segment_start_times,n_select,replace=False)
    segment_start_times.sort()
    return segment_start_times
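
# Illustrative sketch (values are hypothetical, for intuition only): with
# start_time=0, end_time=100, segment_length=5 and pct=0.2, twenty
# non-overlapping 5 s segments fit in the interval, and ceil(20 * 0.2) = 4
# of their start times are drawn at random (without replacement) and sorted:
#     starts = _get_x_pct_time_of_interval(0.0, 100.0, 5.0, 0.2)  # e.g. [5., 35., 60., 90.]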


# Helper for calc_features()
# pure
def _get_bin_absolute_intervals(
        start_seiz              : float,
        end_seiz                : float,
        preictal_bins           : list,
        postictal_delay         : int,
        bin_names               : list,
        all_start_times         : list,
        all_end_times           : list,
        total_session_time      : float
        ):
    """Returns dictionary of all valid bin intervals for a single seizure.

    The function is named 'absolute' because it measures the absolute
    number of seconds from the beginning of the session (aka recording/file).


    Parameters
    ----------
    `start_seiz : float or int`
        The start time of the seizure in question, in seconds. (Relative
        to the start of the session.)

    `end_seiz : float or int`
        The end time of the seizure in question, in seconds. (Relative
        to the start of the session.)

    `preictal_bins : list`
        The list found in the Options.toml config file. It's a list of floats
        that specify, in seconds, the pre-ictal bin timestamps.

    `postictal_delay : int or float`
        Specifies the post-ictal period after the end of a seizure, in seconds
        (e.g. 600).

    `bin_names : list`
        A list of bin names (strings). Must have length len(preictal_bins) + 2.
        This is because we have one single intra-ictal and one single post-
        ictal bin.

    `all_start_times : list`
        A list of the start times of every seizure (in seconds from start
        of session). Needed to check validity of intervals.

    `all_end_times : list`
        A list of the end times of every seizure (in seconds from start of
        session). Needed to check validity of intervals.

    `total_session_time : float`
        Total number of seconds from start to end of the session.

    Returns
    -------
    `dict`
        The keys are names of time bins (e.g. pre_ictal_bin1, post_ictal, etc.)
        If the corresponding bin for the seizure is valid, the value will be a
        tuple (bin_start_time,bin_end_time). If the bin is not valid, the value
        will be None. A seizure's time bin is considered valid if all of these
        conditions are satisfied:
            1. The bin does not overlap with any other seizure.
            2. The bin starts after the beginning of the session.
            3. The bin ends before the end of the session.
    """
    # Some tests and checks
    assert len(preictal_bins) + 2 == len(bin_names), "Args have incompatible length."
    assert len(all_start_times) == len(all_end_times), "Logic err, start and end times lists must match in length."

    bin_name_intraictal = bin_names[-2]     # string
    bin_name_postictal = bin_names[-1]      # string
    bin_names_preictal = bin_names[:-2]     # list of strings

    # Init bin_intervals (the dict we will return)
    # NB: The seizure interval is always valid
    bin_intervals = {bin_name_intraictal : (start_seiz,end_seiz)}

    ### PRE-ICTAL
    # The absolute times of the pre-ictal interval edges (we will iterate through them)
    pre_bins_abs = [start_seiz - i for i in preictal_bins]
    # Important: don't forget the final implied pre-ictal edge
    # (the start of the seizure)
    pre_bins_abs.append(start_seiz)
    for bin_name_preictal,start_bin,end_bin in zip(bin_names_preictal,pre_bins_abs[:-1],pre_bins_abs[1:]):
        # TODO: compare with MatLab
        #       Notably verify if there are missing checks

        bin_intervals[bin_name_preictal] = (start_bin,end_bin)
        # Check whether another (different) seizure overlaps with our bin;
        # if so, define it as None
        for start_seiz_2,end_seiz_2 in zip(all_start_times,all_end_times):
            # TODO: check MatLab, confirm that the postictal delay term is necessary
            condition0 = start_bin < end_seiz_2 + postictal_delay
            condition1 = start_seiz > start_seiz_2
            condition2 = start_bin < 0
            if (condition0 and condition1) or condition2:
                # The pre-ictal bin interval is not valid
                # Re-define it as None and break out of the for-loop
                bin_intervals[bin_name_preictal] = None
                break

    ### POST-ICTAL
    # If the post-ictal interval is valid, define it.
    # Post-ictal is valid iff there is no second seizure right
    # after the one in question.
    end_postictal = end_seiz + postictal_delay
    bin_intervals[bin_name_postictal] = (end_seiz , end_postictal)
    # Check if invalid: redefine to None
    if end_postictal > total_session_time:
        bin_intervals[bin_name_postictal] = None
    for start_seiz_2 in all_start_times:
        if end_postictal > start_seiz_2 and end_seiz < start_seiz_2:
            bin_intervals[bin_name_postictal] = None
    return bin_intervals
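
# Worked example (mirrors unit test 1 at the bottom of this file): for a seizure
# spanning 100-130 s, with preictal_bins=[50, 25, 10], postictal_delay=60 and no
# other seizure nearby, the returned dict is
#     {"pre1": (50, 75), "pre2": (75, 90), "pre3": (90, 100),
#      "intra": (100, 130), "post": (130, 190)}
# with any window that collides with another seizure or falls outside the
# session replaced by None.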


def _get_total_session_time_from_binary(
        binary_filepaths:list,
        fs=2000,
        n_chan_binary=41,
        precision="int16"
        ) -> float:
    """Determines the total session time from binary files."""
    fsize_bytes = os.path.getsize(binary_filepaths[0])
    # Check they are all exactly the same size, as they must be
    for i in binary_filepaths[1:]:
        assert fsize_bytes == os.path.getsize(i), "Corrupt data, raw channel binaries have mismatched filesizes"
    bytes_per_sample = np.dtype(precision).itemsize
    n_samples_per_chan = fsize_bytes / bytes_per_sample / n_chan_binary
    assert n_samples_per_chan == int(n_samples_per_chan), "Logic error, possibly n_chan_binary is incorrect: this is the number of channels saved in the binary file."
    total_duration_in_seconds = n_samples_per_chan / fs # total duration in seconds
    return total_duration_in_seconds
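
# Worked example (illustrative numbers only): a 984 000 000 byte file of int16
# samples with n_chan_binary=41 holds 984e6 / 2 / 41 = 12 000 000 samples per
# channel, i.e. 12 000 000 / 2000 Hz = 6000 s of recording.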


def _get_feats_df_column_names(
        features_list   : list,
        data_ops        : dict,
        ):
    """Calls the featurizing method on dummy segments and forwards the column names it returns."""
    N_CHAN_RAW      = data_ops["N_CHAN_RAW"]
    N_CHAN_BINARY   = data_ops["N_CHAN_BINARY"]
    # Rem: session_basename identifies the recording, time_bin is the class label
    # Dummy segments are random noise (not constant) to avoid a division-by-zero
    # error in the coherence computation (psd=0)
    dummy_segments = np.random.normal(0,1,(N_CHAN_RAW,10000,N_CHAN_BINARY)) # 10000 samples is a 5s segment at 2000Hz
    dummy_features = sm_features.get_feats(dummy_segments,features_list,data_ops)
    colnames = ["session_basename","time_bin"] # first two colnames by default
    colnames += [k for k in dummy_features.keys()] # concatenate
    return colnames
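
# Usage sketch (mirrors how calc_features() below uses this helper; the exact
# feature column names are whatever sm_features.get_feats returns):
#     colnames = _get_feats_df_column_names(["mean_power","coherence"], data_ops)
#     feats_df = pd.DataFrame({name: [] for name in colnames})
#     # colnames always starts with ["session_basename", "time_bin"]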

def _get_match_basename_in_dir_paths(directory,basename):
    """Returns a list of the full paths of all files in directory whose names start with basename."""
    return [os.path.join(directory,i) for i in os.listdir(directory) if i[:len(basename)]==basename]


# Calculates all the features for one session and returns them in a
# pandas dataframe
def calc_features(
        fio_ops : dict,
        data_ops : dict,
        feature_ops : dict,
        session_basename : str,
        ) -> pd.DataFrame:
    """Compute all features in all channels for a single session.

    Several noteworthy assumptions are made: it is assumed that the
    binary data file and the metadata text file share the same basename.
    (Aside: the metadata txt file is parsed by parse_metadata in metadata_io.py)

    Parameters
    ----------
    `fio_ops : dict`
        The fio section of the Options.toml config file (data paths).

    `data_ops : dict`
        The params.data section of the Options.toml config file.

    `feature_ops : dict`
        The params.feature section of the Options.toml config file.

    `session_basename : str`
        The basename identifying the session (recording) to featurize.

    Returns
    -------
    `pd.DataFrame`
        One row per segment, with columns { session_basename , time_bin ,
        feat_01_name , feat_02_name , ... }
    """
    # feature_functions = # TODO: this is a list of functions
    # features = # TODO: 2d np array shape = (n_segments,n_feats_per_segment)

    # Unpack fio_ops params
    RAW_DATA_PATH         = fio_ops["RAW_DATA_PATH"]
    WAVELET_BINARIES_PATH = fio_ops["WAVELET_BINARIES_PATH"]

    # Unpack data params
    SCALE_PHASE     = data_ops["SCALE_PHASE"]
    SCALE_POWER     = data_ops["SCALE_POWER"]
    FS              = data_ops["FS"]
    N_CHAN_RAW      = data_ops["N_CHAN_RAW"]
    N_CHAN_BINARY   = data_ops["N_CHAN_BINARY"]
    PRECISION       = data_ops["PRECISION"]
    # TODO: clean up amp and phase indexing conventions
    amp_idx         = data_ops["AMP_IDX"]
    AMP_IDX = np.array(amp_idx) * 2 + 1
    ph_idx          = data_ops["PH_IDX"]
    PH_IDX  = np.array(ph_idx)  * 2 + 2
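    # Sketch of the indexing convention (an assumption based on the *_FREQ_IDX_ALL
    # options used in the tests below: binary channel 0 appears to hold the raw
    # trace, odd channels wavelet power, even channels wavelet phase):
    #     amp_idx = [0, 1]  ->  AMP_IDX = [1, 3]   (power channels)
    #     ph_idx  = [0, 1]  ->  PH_IDX  = [2, 4]   (phase channels)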

    # Unpack relevant parameters, params.feature in Options.toml
    BIN_NAMES       = feature_ops["BIN_NAMES"]
    PREICTAL_BINS   = feature_ops["PREICTAL"]["BINS"] # in seconds before seizure onset
    POSTICTAL_DELAY = feature_ops["POSTICTAL"]["DELAY"] # in seconds
    PREICTAL_PCT    = feature_ops["PREICTAL"]["PCT"]
    INTRAICTAL_PCT  = feature_ops["INTRAICTAL"]["PCT"]
    POSTICTAL_PCT   = feature_ops["POSTICTAL"]["PCT"]
    PCT     = PREICTAL_PCT + [INTRAICTAL_PCT] + [POSTICTAL_PCT] # Concatenation
    PCT_DIC = {b:pct for b,pct in zip(BIN_NAMES,PCT)} # a handy pct dictionary
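    # Illustrative mapping (values from the test config at the bottom of this file):
    # BIN_NAMES = ["pre1","pre2","pre3","pre4","intra","post"] with preictal PCT
    # [0.05, 0.05, 0.2, 1.0], intraictal PCT 1.0 and postictal PCT 0.2 give
    #     PCT_DIC == {"pre1": 0.05, "pre2": 0.05, "pre3": 0.2,
    #                 "pre4": 1.0, "intra": 1.0, "post": 0.2}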
    DUR_FEAT        = feature_ops["DUR_FEAT"]
    N_SAMPLES       = get_n_samples_from_dur_fs(DUR_FEAT,FS) # n samples per feature segment
    FEATURES        = feature_ops["FEATURES"]

    # Init Pandas DataFrame with the right colnames
    colnames = _get_feats_df_column_names(FEATURES,data_ops)
    feats_df = pd.DataFrame({name:[] for name in colnames})

    # Retrieve seizure times from metadata
    session_metadata_path = os.path.join(RAW_DATA_PATH, session_basename+".txt")
    start_times,end_times = get_seizure_start_end_times(session_metadata_path)
    # All of the session binaries
    session_binary_paths = _get_match_basename_in_dir_paths(WAVELET_BINARIES_PATH,session_basename)
    # Get total session time; assumes all binaries have exactly the same length
    total_session_time = _get_total_session_time_from_binary(
            session_binary_paths,
            fs=FS,
            n_chan_binary=N_CHAN_BINARY,
            precision=PRECISION) # in secs

    # TODO: The following three nested for loops are potentially hard to
    # swallow, consider refactoring them. Q: is this necessary?

    # For each seizure in the session
    print(f"Computing features for session {session_basename}")
    for start_time,end_time in tqdm(list(zip(start_times,end_times))):
        # _get_bin_absolute_intervals() returns a dict with
        # keys=BIN_NAMES, values=(start_bin,end_bin) (in seconds).
        # If the intervals of a bin are not valid, either because
        # they start before or end after the file, or because they
        # overlap with another seizure, the value at bin_name
        # is None.
        bins_absolute_intervals = _get_bin_absolute_intervals(
                start_seiz          = start_time,
                end_seiz            = end_time,
                preictal_bins       = PREICTAL_BINS,
                postictal_delay     = POSTICTAL_DELAY,
                bin_names           = BIN_NAMES,
                all_start_times     = start_times,
                all_end_times       = end_times,
                total_session_time  = total_session_time
                )

        # For each time bin (=interval label) corresponding to this session
        for bin_name in BIN_NAMES:
            interval = bins_absolute_intervals[bin_name] # interval in seconds
            pct      = PCT_DIC[bin_name] # % of intervals to grab, float in (0,1]
            # If the interval is valid, get the segment start-times
            if interval:
                start_time,end_time = interval # unpack interval tuple
                # Get the set of timestamps corresponding to the start of each segment
                segment_starts = _get_x_pct_time_of_interval(
                        start_time      = start_time,
                        end_time        = end_time,
                        segment_length  = DUR_FEAT,
                        pct             = pct
                        )
                # This holds the segments in this time-bin
                bin_segments = np.zeros((
                    N_CHAN_RAW,
                    len(segment_starts),
                    N_SAMPLES,
                    N_CHAN_BINARY
                    ))
                for raw_ch_idx in range(N_CHAN_RAW):
                    # Is it dangerous to use such a strict file-naming convention?
                    # Yes, TODO: refactor all namings of things that derive
                    # from basenames into a utility method
                    session_binary_chan_raw_path = utils.fmt_binary_chan_raw_path(
                            WAVELET_BINARIES_PATH,
                            session_basename,
                            raw_ch_idx
                            )
                    # Load all segments from this specific time bin and raw chan
                    ws = load_binary_multiple_segments(
                            file_path       = session_binary_chan_raw_path,
                            n_chan          = N_CHAN_BINARY,
                            sample_rate     = FS,
                            offset_times    = segment_starts,
                            duration_time   = DUR_FEAT,
                            precision       = PRECISION
                            )
                    assert ws.shape == (len(segment_starts),N_SAMPLES,N_CHAN_BINARY)
                    bin_segments[raw_ch_idx,:,:,:] = ws

                # Get features for each segment
                for segment in bin_segments.transpose((1,0,2,3)):
                    # This is a single segment, all channels, raw and wavelet/binary
                    assert segment.shape == (N_CHAN_RAW,N_SAMPLES,N_CHAN_BINARY)
                    feats = sm_features.get_feats(segment,FEATURES,data_ops)
                    # Add the session name and time bin to the row dictionary
                    feats.update({
                        "session_basename"  : session_basename,
                        "time_bin"          : bin_name
                        })
                    # Add the row to our pandas dataframe
                    feats_df.loc[len(feats_df.index)] = feats
    return feats_df
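
# Minimal usage sketch for a single session (the basename below is a placeholder;
# in practice the ops dictionaries come from Options.toml, as in calc_features_all):
#     ops = load_ops_as_dict(options_path="Options.toml")
#     df = calc_features(ops["fio"], ops["params"]["data"],
#                        ops["params"]["feature"], "some_session_basename")
#     df.to_csv("some_session_features.csv")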

def calc_features_all(options_path="Options.toml"):
    """Computes a features csv for each binary file in the raw binary data directory."""
    # Unpack parameters and user-defined constants
    ops = load_ops_as_dict(options_path=options_path)
    data_ops = ops["params"]["data"]
    fio_ops = ops["fio"]
    feature_ops = ops["params"]["feature"]
    RAW_DATA_PATH = fio_ops["RAW_DATA_PATH"]
    FEATURES_PATH = fio_ops["FEATURES_PATH"]
    valid_basenames = get_all_valid_session_basenames(RAW_DATA_PATH)
    print(f"Computing features for {len(valid_basenames)} sessions...")
    for basename in valid_basenames:
        # Get the features of one session
        df = calc_features(fio_ops,data_ops,feature_ops,session_basename=basename)
        # Save them
        csv_out_path = utils.fmt_features_df_csv_path(FEATURES_PATH,basename)
        df.to_csv(csv_out_path)
    return


if __name__=="__main__":
    import array
    ### UNIT TESTS ###

    ### TEST _get_x_pct_time_of_interval()
    arr = _get_x_pct_time_of_interval(
            start_time     = 5.0,
            end_time       = 152.6,
            segment_length = 1.0,
            pct            = 0.05)
    print("Test _get_x_pct_time_of_interval(), array returned:")
    print(arr)
    print()


    ### TEST _get_bin_absolute_intervals()
    # Test _get_bin_absolute_intervals() 1
    bin_abs_intervals = _get_bin_absolute_intervals(
            start_seiz           = 100,
            end_seiz             = 130,
            preictal_bins        = [50 , 25 , 10],
            postictal_delay      = 60,
            bin_names            = ["pre1","pre2","pre3","intra","post"],
            all_start_times      = [100,500,1000],
            all_end_times        = [130,550,1100],
            total_session_time   = 2000
            )
    assert bin_abs_intervals["pre1"] == (50,75)
    assert bin_abs_intervals["pre2"] == (75,90)
    assert bin_abs_intervals["pre3"] == (90,100)
    assert bin_abs_intervals["intra"] == (100,130)
    assert bin_abs_intervals["post"] == (130,190)

    # Test _get_bin_absolute_intervals() 2
    bin_abs_intervals = _get_bin_absolute_intervals(
            start_seiz           = 100,
            end_seiz             = 130,
            preictal_bins        = [50 , 25 , 10],
            postictal_delay      = 10,
            bin_names            = ["pre1","pre2","pre3","intra","post"],
            all_start_times      = [50,100,135],
            all_end_times        = [60,130,170],
            total_session_time   = 2000
            )
    assert bin_abs_intervals["pre1"] == None    # Overlaps with previous post-ictal
    assert bin_abs_intervals["pre2"] == (75,90)
    assert bin_abs_intervals["pre3"] == (90,100)
    assert bin_abs_intervals["intra"] == (100,130)
    assert bin_abs_intervals["post"] == None        # Overlaps with next seizure

    # Test _get_bin_absolute_intervals() 3
    bin_abs_intervals = _get_bin_absolute_intervals(
            start_seiz           = 15,
            end_seiz             = 100,
            preictal_bins        = [50 , 25 , 10],
            postictal_delay      = 60,
            bin_names            = ["pre1","pre2","pre3","intra","post"],
            all_start_times      = [15],
            all_end_times        = [100],
            total_session_time   = 150
            )
    assert bin_abs_intervals["pre1"] == None        # Before file start
    assert bin_abs_intervals["pre2"] == None        # Before file start
    assert bin_abs_intervals["pre3"] == (5,15)      # Valid
    assert bin_abs_intervals["intra"] == (15,100)   # Valid
    assert bin_abs_intervals["post"] == None        # Ends after end of file
    # Not every single edge-case is tested... (low priority TODO)

    ### TEST _get_total_session_time()
    # Generate an array, serialize it, then measure its length
    arr = array.array("h", np.arange(256)) # highly divisible-by-two test array
    with open("temp_test.dat","wb") as file:
        arr.tofile(file)
    fs,n_chan_binary = 2,4
    total_duration_in_seconds = _get_total_session_time_from_binary(
        binary_filepaths = ["./temp_test.dat"],
        fs=fs,
        n_chan_binary=n_chan_binary,
        precision="int16"
        )
    os.remove("temp_test.dat") # delete the test file
    try:
        assert total_duration_in_seconds == len(arr) // fs // n_chan_binary
        print("Passed Test _get_total_session_time()")
    except AssertionError:
        print("Failed Test _get_total_session_time()")

    ### TEST _get_feats_df_column_names(), and sm_features.get_feats() implicitly
    features = ["mean_power","coherence"]
    data_ops = {"FS":2000,"NUM_FREQ":20,"LOW_FREQ":0.5,"HIGH_FREQ":200,
            "SPACING":"LOGARITHMIC","ZSCORE_POWER":True,"SCALE_PHASE":1000,
            "SCALE_POWER":1000,"N_CHAN_RAW":4,"CH_PHASE_AMP":2,
            "TS_FEATURES":["WAVELET_POWER","WAVELET_PHASE"],
            "N_CHAN_BINARY":3,"PRECISION":"int16",
            "AMP_IDX":[],#[14,15,16,17,18,19],
            "PH_IDX":[0],#[0,1,2,3,4,5,6,7,8,9],
            "AMP_FREQ_IDX_ALL":[1],#,3,5,7,9,11,13,15,17,19,21,23,25,27,29,31,33,35,37,39],
            "PH_FREQ_IDX_ALL":[2]}#,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36,38,40]}
    colnames = _get_feats_df_column_names(features,data_ops)
    print("\nTest, first eight column names are:")
    for i in colnames[:8]: print(i)

    ### TEST calc_features
    fio_ops = {"RAW_DATA_PATH":"/Users/steve/Documents/code/unm/data/h24_data/raw",
            "WAVELET_BINARIES_PATH":"/Users/steve/Documents/code/unm/data/h24_data/binaries"}
    # data_ops: use the same dict as in the sm_features.get_feats() test above,
    # but update it to include all amplitude and phase indices
    data_ops["N_CHAN_BINARY"] = 41
    data_ops["AMP_IDX"] = [14,15,16,17,18,19]
    data_ops["PH_IDX"] = [0,1,2,3,4,5,6,7,8,9]
    data_ops["AMP_FREQ_IDX_ALL"] = [1,3,5,7,9,11,13,15,17,19,21,23,25,27,29,31,33,35,37,39]
    data_ops["PH_FREQ_IDX_ALL"] = [2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36,38,40]
    feature_ops = {"DUR_FEAT":5,"N_PREICTAL_BINS":4,
            "PREICTAL":{"BINS":[10800,3600,600,10],"PCT":[0.05,0.05,0.2,1.0]},
            "INTRAICTAL":{"PCT":1.0},
            "POSTICTAL":{"DELAY":600,"PCT":0.2},
            "BIN_NAMES":["pre1","pre2","pre3",
                "pre4","intra","post"],
            "FEATURES":["mean_power","var","coherence"]}
    session_basename = "AC75a-5 DOB 072519_TS_2020-03-23_17_30_04"
    feats_df = calc_features(fio_ops,data_ops,feature_ops,session_basename)
    # Serialize the features dataframe to look at it in a jupyter notebook
    print("Writing to CSV 'feats_df.csv'")
    feats_df.to_csv("feats_df.csv")
    # saveas = "feats_df.pkl"
    # print(f"Pickling features {saveas}")
    # feats_df.to_pickle(saveas)
    print("Passed test calc_features()")

    print("\nTests All Passed: _get_bin_absolute_intervals()")