diff --git a/.gitignore b/.gitignore index bfc40973..79c11978 100644 --- a/.gitignore +++ b/.gitignore @@ -45,4 +45,20 @@ test_*.py *.csv .vscode/* test_delsys_api.py -resources/ \ No newline at end of file +resources/ +*.csv +*.txt +ContinuousTransitions/* +FORS-EMG/* +MyoDisCo/* +NinaProDB1/* +*.zip +libemg/_datasets/__pycache__/* +CIILData/* +EMGEPN612.pkl +OneSubjectMyoDataset/ +_3DCDataset/ +ContractionIntensity/ +CIILData/ +*.pkl +LimbPosition/ \ No newline at end of file diff --git a/dataset_tryout.py b/dataset_tryout.py new file mode 100644 index 00000000..908a8ba1 --- /dev/null +++ b/dataset_tryout.py @@ -0,0 +1,4 @@ +from libemg.datasets import * + +accs = evaluate('LDA', 300, 100, feature_list=['MAV','SSC','ZC','WL'], included_datasets=['FougnerLP'], save_dir='') +print('\n' + str(accs)) \ No newline at end of file diff --git a/docs/source/examples/offline_regression_example/offline_regression.md b/docs/source/examples/offline_regression_example/offline_regression.md new file mode 100644 index 00000000..7ecb798b --- /dev/null +++ b/docs/source/examples/offline_regression_example/offline_regression.md @@ -0,0 +1,98 @@ +[View Source Code](https://github.com/LibEMG/LibEMG_OfflineRegression_Showcase) + + + +This simple offline example showcases some of the offline capabilities for regression analysis. In this example, we will load in the OneSubjectEMaGerDataset and assess the performance of multiple regressors. All code can be found in `main.py`. + +## Step 1: Importing LibEMG + +The very first step involves importing the modules needed. In general, each of LibEMG's modules has its own import. Make sure that you have successfully installed libemg through pip. + +```Python +import numpy as np +import matplotlib.pyplot as plt +from libemg.offline_metrics import OfflineMetrics +from libemg.datasets import OneSubjectEMaGerDataset +from libemg.feature_extractor import FeatureExtractor +from libemg.emg_predictor import EMGRegressor +``` + +## Step 2: Setting up Constants + +Preprocessing parameters, such as window size, window increment, and the feature set must be decided before EMG data can be prepared for estimation. LibEMG defines window and increment sizes as the number of samples. In this case, the dataset was recorded from the EMaGer cuff, which samples at 1 kHz, so a window of 150 samples corresponds to 150ms. + +The window increment, window size, and feature set default to 40, 150, and 'HTD', respecively. These variables can be customized in this script using the provided CLI. Use `python main.py -h` for an explanation of the CLI. Example usage is also provided below: + +```Bash +python main.py --window_size 200 --window_increment 50 --feature_set MSWT +``` + +# Step 3: Loading in Dataset + +This example uses the `OneSubjectEMaGerDataset`. Instantiating the `Dataset` will automatically download the data into the specified directory, and calling the `prepare_data()` method will load EMG data and metadata (e.g., reps, movements, labels) into an `OfflineDataHandler`. This dataset consists of 5 repetitions, so we use 4 for training data and 1 for testing data. After splitting our data into training and test splits, we perform windowing on the raw EMG data. By default, the metadata assigned to each window will be based on the mode of that window. Since we are analyzing regression data, we pass in a function that tells the `OfflineDataHandler` to grab the label from the last sample in the window instead of taking the mode of the window. We can specify how we want to handle windowing of each type of metadata by passing in a `metadata_operations` dictionary. + +```Python +# Load data +odh = OneSubjectEMaGerDataset().prepare_data() + +# Split into train/test reps +train_odh = odh.isolate_data('reps', [0, 1, 2, 3]) +test_odh = odh.isolate_data('reps', [4]) + +# Extract windows +metadata_operations = {'labels': lambda x: x[-1]} # grab label of last sample in window +train_windows, train_metadata = train_odh.parse_windows(args.window_size, args.window_increment, metadata_operations=metadata_operations) +test_windows, test_metadata = test_odh.parse_windows(args.window_size, args.window_increment, metadata_operations=metadata_operations) +``` + +# Step 4: Feature Extraction + +We then extract features using the `FeatureExtractor` for our training and test data. The `fit()` method expects a dictionary with the keys `training_features` and `training_labels`, so we create one and pass in our extracted features and training labels. + +```Python +training_features = fe.extract_feature_group(args.feature_set, train_windows, array=True), +training_labels = train_metadata['labels'] +test_features = fe.extract_feature_group(args.feature_set, test_windows, array=True) +test_labels = test_metadata['labels'] + +training_set = { + 'training_features': training_features, + 'training_labels': training_labels +} +``` + +# Step 5: Regression + +`LibEMG` allows you to pass in custom models, but you can also pass in a string that will create a model for you. In this example, we compare a linear regressor to a gradient boosting regressor. We iterate through a list of the models we want to observe, fit the model to the training data, and calculate metrics based on predictions on the test data. We then store these metrics for plotting later. + +```Python +results = {metric: [] for metric in ['R2', 'NRMSE', 'MAE']} +for model in models: + reg = EMGRegressor(model) + + # Fit and run model + print(f"Fitting {model}...") + reg.fit(training_set.copy()) + predictions = reg.run(test_features) + + metrics = om.extract_offline_metrics(results.keys(), test_labels, predictions) + for metric in metrics: + results[metric].append(metrics[metric].mean()) +``` + +# Step 6: Visualization + +Finally, we visualize our results. We first plot the decision stream for each model. After each model is fitted, we plot the offline metrics for each type of model. + +```Python +# Note: this will block the main thread once the plot is shown. Close the plot to continue execution. +reg.visualize(test_labels, predictions) + +fig, axs = plt.subplots(nrows=len(results), layout='constrained', figsize=(8, 8), sharex=True) +for metric, ax in zip(results.keys(), axs): + ax.bar(models, np.array(results[metric]) * 100) + ax.set_ylabel(f"{metric} (%)") + +fig.suptitle('Metrics Summary') +plt.show() +``` diff --git a/docs/source/examples/offline_regression_example/offline_regression_example.rst b/docs/source/examples/offline_regression_example/offline_regression_example.rst new file mode 100644 index 00000000..7e8707f7 --- /dev/null +++ b/docs/source/examples/offline_regression_example/offline_regression_example.rst @@ -0,0 +1,4 @@ +Offline Regression Analysis +========================================== +.. include:: offline_regression.md + :parser: myst_parser.sphinx_ \ No newline at end of file diff --git a/docs/source/index.rst b/docs/source/index.rst index 90ff5a3d..4e70e20a 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -43,6 +43,7 @@ LibEMG examples/features_and_group_example/features_and_group_example examples/feature_optimization_example/feature_optimization_example examples/deep_learning_example/deep_learning_example + examples/offline_regression_example/offline_regression_example .. toctree:: :maxdepth: 1 diff --git a/libemg/_datasets/_3DC.py b/libemg/_datasets/_3DC.py new file mode 100644 index 00000000..daacbced --- /dev/null +++ b/libemg/_datasets/_3DC.py @@ -0,0 +1,45 @@ +from libemg._datasets.dataset import Dataset +from libemg.data_handler import OfflineDataHandler, RegexFilter + +class _3DCDataset(Dataset): + def __init__(self, dataset_folder="_3DCDataset/"): + Dataset.__init__(self, + 1000, + 10, + '3DC Armband (Prototype)', + 22, + {0: "Neutral", 1: "Radial Deviation", 2: "Wrist Flexion", 3: "Ulnar Deviation", 4: "Wrist Extension", 5: "Supination", 6: "Pronation", 7: "Power Grip", 8: "Open Hand", 9: "Chuck Grip", 10: "Pinch Grip"}, + '8 (4 Train, 4 Test)', + "The 3DC dataset including 11 classes.", + "https://doi.org/10.3389/fbioe.2020.00158") + self.url = "https://github.com/libemg/3DCDataset" + self.dataset_folder = dataset_folder + + def prepare_data(self, split = False, subjects_values = None, sets_values = None, reps_values = None, + classes_values = None): + if subjects_values is None: + subjects_values = [str(i) for i in range(1,23)] + if sets_values is None: + sets_values = ["train", "test"] + if reps_values is None: + reps_values = ["0","1","2","3"] + if classes_values is None: + classes_values = [str(i) for i in range(11)] + + print('\nPlease cite: ' + self.citation+'\n') + if (not self.check_exists(self.dataset_folder)): + self.download(self.url, self.dataset_folder) + + regex_filters = [ + RegexFilter(left_bound = "/", right_bound="/EMG", values = sets_values, description='sets'), + RegexFilter(left_bound = "_", right_bound=".txt", values = classes_values, description='classes'), + RegexFilter(left_bound = "EMG_gesture_", right_bound="_", values = reps_values, description='reps'), + RegexFilter(left_bound="Participant", right_bound="/",values=subjects_values, description='subjects') + ] + odh = OfflineDataHandler() + odh.get_data(folder_location=self.dataset_folder, regex_filters=regex_filters, delimiter=",") + data = odh + if split: + data = {'All': odh, 'Train': odh.isolate_data("sets", [0], fast=True), 'Test': odh.isolate_data("sets", [1], fast=True)} + + return data \ No newline at end of file diff --git a/libemg/_datasets/__init__.py b/libemg/_datasets/__init__.py new file mode 100644 index 00000000..892b01cf --- /dev/null +++ b/libemg/_datasets/__init__.py @@ -0,0 +1,17 @@ +from libemg._datasets import _3DC +from libemg._datasets import ciil +from libemg._datasets import continous_transitions +from libemg._datasets import dataset +from libemg._datasets import emg_epn612 +from libemg._datasets import fors_emg +from libemg._datasets import fougner_lp +from libemg._datasets import grab_myo +from libemg._datasets import hyser +from libemg._datasets import intensity +from libemg._datasets import kaufmann_md +from libemg._datasets import myodisco +from libemg._datasets import nina_pro +from libemg._datasets import one_subject_emager +from libemg._datasets import one_subject_myo +from libemg._datasets import radmand_lp +from libemg._datasets import tmr_shirleyryanabilitylab \ No newline at end of file diff --git a/libemg/_datasets/ciil.py b/libemg/_datasets/ciil.py new file mode 100644 index 00000000..8b9e7a3d --- /dev/null +++ b/libemg/_datasets/ciil.py @@ -0,0 +1,148 @@ +from libemg._datasets.dataset import Dataset +from libemg.data_handler import OfflineDataHandler, RegexFilter, FilePackager +from pathlib import Path + + + +class CIIL_MinimalData(Dataset): + def __init__(self, dataset_folder='CIILData/'): + Dataset.__init__(self, + 200, + 8, + 'Myo Armband', + 11, + {0: 'Close', 1: 'Open', 2: 'Rest', 3: 'Flexion', 4: 'Extension'}, + '1 Train (1s), 15 Test', + "The goal of this Myo dataset is to explore how well models perform when they have a limited amount of training data (1s per class).", + 'https://ieeexplore.ieee.org/abstract/document/10394393') + self.url = "https://github.com/LibEMG/CIILData" + self.dataset_folder = dataset_folder + + def prepare_data(self, split = False): + print('\nPlease cite: ' + self.citation+'\n') + if (not self.check_exists(self.dataset_folder)): + self.download(self.url, self.dataset_folder) + + subfolder = 'MinimalTrainingData' + subjects = [str(i) for i in range(0, 11)] + classes_values = [str(i) for i in range(0,5)] + reps_values = ["0","1","2"] + sets = ["train", "test"] + regex_filters = [ + RegexFilter(left_bound = "/", right_bound="/", values = sets, description='sets'), + RegexFilter(left_bound = "/subject", right_bound="/", values = subjects, description='subjects'), + RegexFilter(left_bound = "R_", right_bound="_", values = reps_values, description='reps'), + RegexFilter(left_bound = "C_", right_bound=".csv", values = classes_values, description='classes') + ] + odh = OfflineDataHandler() + odh.get_data(folder_location=self.dataset_folder + '/' + subfolder, regex_filters=regex_filters, delimiter=",") + + data = odh + if split: + data = {'All': odh, 'Train': odh.isolate_data("sets", [0], fast=True), 'Test': odh.isolate_data("sets", [1], fast=True)} + + return data + +class CIIL_ElectrodeShift(Dataset): + def __init__(self, dataset_folder='CIILData/'): + Dataset.__init__(self, + 200, + 8, + 'Myo Armband', + 21, + {0: 'Close', 1: 'Open', 2: 'Rest', 3: 'Flexion', 4: 'Extension'}, + '5 Train (Before Shift), 8 Test (After Shift)', + "An electrode shift confounding factors dataset.", + 'https://link.springer.com/article/10.1186/s12984-024-01355-4') + self.url = "https://github.com/LibEMG/CIILData" + self.dataset_folder = dataset_folder + + def prepare_data(self, split = False): + print('\nPlease cite: ' + self.citation+'\n') + if (not self.check_exists(self.dataset_folder)): + self.download(self.url, self.dataset_folder) + + subfolder = 'ElectrodeShift' + subjects = [str(i) for i in range(0, 21)] + classes_values = [str(i) for i in range(0,5)] + reps_values = ["0","1","2","3","4"] + sets = ["training", "trial_1", "trial_2", "trial_3", "trial_4"] + regex_filters = [ + RegexFilter(left_bound = "/", right_bound="/", values = sets, description='sets'), + RegexFilter(left_bound = "/subject", right_bound="/", values = subjects, description='subjects'), + RegexFilter(left_bound = "R_", right_bound="_", values = reps_values, description='reps'), + RegexFilter(left_bound = "C_", right_bound=".csv", values = classes_values, description='classes') + ] + odh = OfflineDataHandler() + odh.get_data(folder_location=self.dataset_folder + '/' + subfolder, regex_filters=regex_filters, delimiter=",") + + data = odh + if split: + data = {'All': odh, 'Train': odh.isolate_data("sets", [0], fast=True), 'Test': odh.isolate_data("sets", [1,2,3,4], fast=True)} + + return data + + +class CIIL_WeaklySupervised(Dataset): + def __init__(self, dataset_folder='CIIL_WeaklySupervised/'): + Dataset.__init__(self, + 1000, + 8, + 'OyMotion gForcePro+ EMG Armband', + 16, + {0: 'Close', 1: 'Open', 2: 'Rest', 3: 'Flexion', 4: 'Extension'}, + '30 min weakly supervised, 1 rep calibration, 14 reps test', + "A weakly supervised environment with sparse supervised calibration.", + 'In Submission') + self.url = "https://unbcloud-my.sharepoint.com/:u:/g/personal/ecampbe2_unb_ca/EaABHYybhfJNslTVcvwPPwgB9WwqlTLCStui30maqY53kw?e=MbboMd" + self.dataset_folder = dataset_folder + + def prepare_data(self, split = False): + print('\nPlease cite: ' + self.citation+'\n') + if (not self.check_exists(self.dataset_folder)): + self.download_via_onedrive(self.url, self.dataset_folder) + + # supervised odh loading + subjects = [str(i) for i in range(0, 16)] + classes_values = [str(i) for i in range(0,5)] + reps_values = [str(i) for i in range(0,15)] + setting_values = [".csv", ""] # this is arbitrary to get a field that separates WS from S + regex_filters = [ + RegexFilter(left_bound = "", right_bound="", values = setting_values, description='settings'), + RegexFilter(left_bound = "/S", right_bound="/", values = subjects, description='subjects'), + RegexFilter(left_bound = "R", right_bound=".csv", values = reps_values, description='reps'), + RegexFilter(left_bound = "C", right_bound="_R", values = classes_values, description='classes') + ] + odh_s = OfflineDataHandler() + odh_s.get_data(folder_location=self.dataset_folder+"CIIL_WeaklySupervised/", + regex_filters=regex_filters, + delimiter=",") + + # weakly supervised odh loading + subjects = [str(i) for i in range(0, 16)] + reps_values = [str(i) for i in range(3)] + setting_values = ["", ".csv"] # this is arbitrary to get a field that separates WS from S + regex_filters = [ + RegexFilter(left_bound = "", right_bound="", values = setting_values, description='settings'), + RegexFilter(left_bound = "/S", right_bound="/", values = subjects, description='subjects'), + RegexFilter(left_bound = "WS", right_bound=".csv", values = reps_values, description='reps'), + ] + metadata_fetchers = [ + FilePackager(regex_filter=RegexFilter(left_bound="", right_bound="targets.csv", values=["_"], description="classes"), + package_function=lambda x, y: (x.split("WS")[1][0] == y.split("WS")[1][0]) and (Path(x).parent == Path(y).parent) + ) + ] + odh_ws = OfflineDataHandler() + odh_ws.get_data(folder_location=self.dataset_folder+"CIIL_WeaklySupervised/", + regex_filters=regex_filters, + metadata_fetchers=metadata_fetchers, + delimiter=",") + + data = odh_s + odh_ws + if split: + data = {'All': data, + 'Pretrain': odh_ws, + 'Train': odh_s.isolate_data("reps", [0], fast=True), + 'Test': odh_s.isolate_data("reps", list(range(1,15)), fast=True)} + + return data diff --git a/libemg/_datasets/continous_transitions.py b/libemg/_datasets/continous_transitions.py new file mode 100644 index 00000000..316487b3 --- /dev/null +++ b/libemg/_datasets/continous_transitions.py @@ -0,0 +1,63 @@ +from libemg._datasets.dataset import Dataset +from libemg.data_handler import OfflineDataHandler +import h5py +import numpy as np + +class ContinuousTransitions(Dataset): + def __init__(self, dataset_folder="ContinuousTransitions/"): + Dataset.__init__(self, + 2000, + 6, + 'Delsys', + 43, + {0: 'No Motion', 1: 'Wrist Flexion', 2: 'Wrist Extension', 3: 'Wrist Pronation', 4: 'Wrist Supination', 5: 'Hand Close', 6: 'Hand Open'}, + '6 Training (Ramp), 42 Transitions (All combinations of Transitions) x 6 Reps', + "The testing set in this dataset has continuous transitions between classes which is a more realistic offline evaluation standard for myoelectric control.", + "https://ieeexplore.ieee.org/document/10254242") + self.dataset_folder = dataset_folder + + def prepare_data(self, split = False): + print('\nPlease cite: ' + self.citation+'\n') + if (not self.check_exists(self.dataset_folder)): + print("Please download the dataset from: ") #TODO: Update + return + + # Training ODH + odh_tr = OfflineDataHandler() + odh_tr.subjects = [] + odh_tr.classes = [] + odh_tr.extra_attributes = ['subjects', 'classes'] + + # Testing ODH + odh_te = OfflineDataHandler() + odh_te.subjects = [] + odh_te.classes = [] + odh_te.extra_attributes = ['subjects', 'classes'] + + for s_i, s in enumerate([2,3,4,5,6,7,8,9,10,11,12,13,14,15,17,18,19,20,21,22,23,25,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47]): + data = h5py.File('ContinuousTransitions/P' + f"{s:02}" + '.hdf5', "r") + cont_labels = data['continuous']['emg']['prompt'][()] + cont_labels = np.hstack([np.ones((1000)) * cont_labels[0], cont_labels[0:len(cont_labels)-1000]]) # Rolling about 0.5s as per Shri's suggestion + cont_emg = data['continuous']['emg']['signal'][()] + cont_chg_idxs = np.insert(np.where(cont_labels[:-1] != cont_labels[1:])[0], 0, -1) + cont_chg_idxs = np.insert(cont_chg_idxs, len(cont_chg_idxs), len(cont_emg)) + for i in range(0, len(cont_chg_idxs)-1): + odh_te.data.append(cont_emg[cont_chg_idxs[i]+1:cont_chg_idxs[i+1]]) + odh_te.classes.append(np.expand_dims(cont_labels[cont_chg_idxs[i]+1:cont_chg_idxs[i+1]]-1, axis=1)) + odh_te.subjects.append(np.ones((len(odh_te.data[-1]), 1)) * s_i) + + ramp_emg = data['ramp']['emg']['signal'][()] + ramp_labels = data['ramp']['emg']['prompt'][()] + r_chg_idxs = np.insert(np.where(ramp_labels[:-1] != ramp_labels[1:])[0], 0, -1) + r_chg_idxs = np.insert(r_chg_idxs, len(r_chg_idxs), len(ramp_emg)) + for i in range(0, len(r_chg_idxs)-1): + odh_tr.data.append(ramp_emg[r_chg_idxs[i]+1:r_chg_idxs[i+1]]) + odh_tr.classes.append(np.expand_dims(ramp_labels[r_chg_idxs[i]+1:r_chg_idxs[i+1]]-1, axis=1)) + odh_tr.subjects.append(np.ones((len(odh_tr.data[-1]), 1)) * s_i) + + odh_all = odh_tr + odh_te + data = odh_all + if split: + data = {'All': odh_all, 'Train': odh_tr, 'Test': odh_te} + + return data diff --git a/libemg/_datasets/dataset.py b/libemg/_datasets/dataset.py new file mode 100644 index 00000000..f43bef77 --- /dev/null +++ b/libemg/_datasets/dataset.py @@ -0,0 +1,55 @@ +import os +from libemg.data_handler import OfflineDataHandler +from onedrivedownloader import download as onedrive_download +# this assumes you have git downloaded (not pygit, but the command line program git) + +class Dataset: + def __init__(self, sampling, num_channels, recording_device, num_subjects, gestures, num_reps, description, citation): + # Every class should have this + self.sampling=sampling + self.num_channels=num_channels + self.recording_device=recording_device + self.num_subjects=num_subjects + self.gestures=gestures + self.num_reps=num_reps + self.description=description + self.citation=citation + + def download(self, url, dataset_name): + clone_command = "git clone " + url + " " + dataset_name + os.system(clone_command) + + def download_via_onedrive(self, url, dataset_name, unzip=True, clean=True): + onedrive_download(url=url, + filename = dataset_name, + unzip=unzip, + clean=clean) + + def remove_dataset(self, dataset_folder): + remove_command = "rm -rf " + dataset_folder + os.system(remove_command) + + def check_exists(self, dataset_folder): + return os.path.exists(dataset_folder) + + def prepare_data(self, split = False): + pass + + def get_info(self): + print(str(self.description) + '\n' + 'Sampling Rate: ' + str(self.sampling) + '\nNumber of Channels: ' + str(self.num_channels) + + '\nDevice: ' + self.recording_device + '\nGestures: ' + str(self.gestures) + '\nNumber of Reps: ' + str(self.num_reps) + '\nNumber of Subjects: ' + str(self.num_subjects) + + '\nCitation: ' + str(self.citation)) + +# given a directory, return a list of files in that directory matching a format +# can be nested +# this is just a handly utility +def find_all_files_of_type_recursively(dir, terminator): + files = os.listdir(dir) + file_list = [] + for file in files: + if file.endswith(terminator): + file_list.append(dir+file) + else: + if os.path.isdir(dir+file): + file_list += find_all_files_of_type_recursively(dir+file+'/',terminator) + return file_list \ No newline at end of file diff --git a/libemg/_datasets/emg_epn612.py b/libemg/_datasets/emg_epn612.py new file mode 100644 index 00000000..7412f25b --- /dev/null +++ b/libemg/_datasets/emg_epn612.py @@ -0,0 +1,96 @@ +from libemg._datasets.dataset import Dataset +from libemg.data_handler import OfflineDataHandler +import pickle +import numpy as np +from libemg.feature_extractor import FeatureExtractor +from libemg.utils import * + +class EMGEPN612(Dataset): + def __init__(self, dataset_file='EMGEPN612.pkl'): + Dataset.__init__(self, + 200, + 8, + 'Myo Armband', + 612, + {0: 'Close', 1: 'Open', 2: 'Rest', 3: 'Flexion', 4: 'Extension'}, + '50 Reps x 306 Users (Train), 25 Reps x 306 Users (Test)', + "A large 612 user dataset for developing cross user models.", + 'https://doi.org/10.5281/zenodo.4421500') + self.url = "https://unbcloud-my.sharepoint.com/:u:/g/personal/ecampbe2_unb_ca/EWf3sEvRxg9HuAmGoBG2vYkBLyFv6UrPYGwAISPDW9dBXw?e=vjCA14" + self.dataset_name = dataset_file + + def prepare_data(self, split = False): + print('\nPlease cite: ' + self.citation+'\n') + if (not self.check_exists(self.dataset_name)): + self.download_via_onedrive(self.url, self.dataset_name, unzip=False, clean=False) + + file = open(self.dataset_name, 'rb') + data = pickle.load(file) + + emg = data[0] + labels = data[2] + + odh_tr = OfflineDataHandler() + odh_tr.subjects = [] + odh_tr.classes = [] + odh_tr.reps = [] + tr_reps = [0,0,0,0,0,0] + odh_tr.extra_attributes = ['subjects', 'classes', 'reps'] + for i, e in enumerate(emg['training']): + odh_tr.data.append(e) + odh_tr.classes.append(np.ones((len(e), 1)) * labels['training'][i]) + odh_tr.subjects.append(np.ones((len(e), 1)) * i//300) + odh_tr.reps.append(np.ones((len(e), 1)) * tr_reps[labels['training'][i]]) + tr_reps[labels['training'][i]] += 1 + if i % 300 == 0: + tr_reps = [0,0,0,0,0,0] + odh_te = OfflineDataHandler() + odh_te.subjects = [] + odh_te.classes = [] + odh_te.reps = [] + te_reps = [0,0,0,0,0,0] + odh_te.extra_attributes = ['subjects', 'classes', 'reps'] + for i, e in enumerate(emg['testing']): + odh_te.data.append(e) + odh_te.classes.append(np.ones((len(e), 1)) * labels['testing'][i]) + odh_te.subjects.append(np.ones((len(e), 1)) * (i//150 + 306)) + odh_te.reps.append(np.ones((len(e), 1)) * te_reps[labels['training'][i]]) + te_reps[labels['training'][i]] += 1 + if i % 150 == 0: + te_reps = [0,0,0,0,0,0] + + # odh_tr = self._update_odh(odh_tr) + # odh_te = self._update_odh(odh_te) + odh_all = odh_tr + odh_te + + data = odh_all + if split: + data = {'All': odh_all, 'Train': odh_tr, 'Test': odh_te} + + return data + + def _update_odh(self, odh): + pass + # fe = FeatureExtractor() + # for i_e, e in enumerate(odh.data): + # if odh.classes[i_e][0][0] == 0: + # # It is no motion and we need to crop it (make datset even) + # odh.data[i_e] = e[100:200] + # odh.subjects[i_e] = odh.subjects[i_e][100:200] + # odh.classes[i_e] = odh.classes[i_e][100:200] + # odh.reps[i_e] = odh.reps[i_e][100:200] + # else: + # # It is an active class and we are croppign it + # if len(e) > 100: + # windows = get_windows(e, 20, 5) + # feats = fe.extract_features(['MAV'], windows, array=True) + # mval = np.argmax(np.mean(feats, axis=1)) * 5 + # max_idx = min([len(e), mval + 50]) + # min_idx = max([0, mval - 50]) + # odh.data[i_e] = e[min_idx:max_idx] + # odh.subjects[i_e] = odh.subjects[i_e][min_idx:max_idx] + # odh.classes[i_e] = odh.classes[i_e][min_idx:max_idx] + # odh.reps[i_e] = odh.reps[i_e][min_idx:max_idx] + # return odh + + \ No newline at end of file diff --git a/libemg/_datasets/fors_emg.py b/libemg/_datasets/fors_emg.py new file mode 100644 index 00000000..24adea4f --- /dev/null +++ b/libemg/_datasets/fors_emg.py @@ -0,0 +1,52 @@ +from libemg._datasets.dataset import Dataset +from libemg.data_handler import OfflineDataHandler, RegexFilter +import scipy.io +import numpy as np + +class FORSEMG(Dataset): + def __init__(self, dataset_folder='FORS-EMG/'): + Dataset.__init__(self, + 985, + 8, + 'Experimental Device', + 19, + {0: 'Thump Up', 1: 'Index', 2: 'Right Angle', 3: 'Peace', 4: 'Index Little', 5: 'Thumb Little', 6: 'Hand Close', 7: 'Hand Open', 8: 'Wrist Flexion', 9: 'Wrist Extension', 10: 'Ulnar Deviation', 11: 'Radial Deviation'}, + '5 Train, 10 Test (2 Forarm Orientations x 5 Reps)', + "FORS-EMG: Twelve gestures elicited in three forearm orientations (neutral, pronation, and supination).", + 'https://arxiv.org/abs/2409.07484t') + self.dataset_folder = dataset_folder + + def prepare_data(self, split = False): + print('\nPlease cite: ' + self.citation+'\n') + if (not self.check_exists(self.dataset_folder)): + print("Please download the dataset from: https://www.kaggle.com/datasets/ummerummanchaity/fors-emg-a-novel-semg-dataset?resource=download") + return + + odh = OfflineDataHandler() + odh.subjects = [] + odh.classes = [] + odh.reps = [] + odh.orientation = [] + odh.extra_attributes = ['subjects', 'classes', 'reps', 'orientation'] + + for s in range(1, 20): + for g_i, g in enumerate(['Thumb_UP', 'Index', 'Right_Angle', 'Peace', 'Index_Little', 'Thumb_Little', 'Hand_Close', 'Hand_Open', 'Wrist_Flexion', 'Wrist_Extension', 'Radial_Deviation']): + for r in [1,2,3,4,5]: + for o_i, o in enumerate(['Rest', 'Pronation', 'Supination']): + try: + mat = scipy.io.loadmat('FORS-EMG/Subject' + str(s) + '/' + o + '/' + g + '-' + str(r) + '.mat') + except: + o = o.lower() + mat = scipy.io.loadmat('FORS-EMG/Subject' + str(s) + '/' + o + '/' + g + '-' + str(r) + '.mat') + + odh.data.append(mat['value'].T) + odh.classes.append(np.ones((len(odh.data[-1]), 1)) * g_i) + odh.subjects.append(np.ones((len(odh.data[-1]), 1)) * s-1) + odh.reps.append(np.ones((len(odh.data[-1]), 1)) * r-1) + odh.orientation.append(np.ones((len(odh.data[-1]), 1)) * o_i) + + data = odh + if split: + data = {'All': odh, 'Train': odh.isolate_data('orientation', [0], fast=True), 'Test': odh.isolate_data('orientation', [1,2], fast=True)} + + return data diff --git a/libemg/_datasets/fougner_lp.py b/libemg/_datasets/fougner_lp.py new file mode 100644 index 00000000..df55c7c5 --- /dev/null +++ b/libemg/_datasets/fougner_lp.py @@ -0,0 +1,40 @@ +from libemg._datasets.dataset import Dataset +from libemg.data_handler import OfflineDataHandler, RegexFilter + +class FougnerLP(Dataset): + def __init__(self, dataset_folder="LimbPosition/"): + Dataset.__init__(self, + 1000, + 8, + 'BE328 by Liberating Technologies, Inc.', + 12, + {0: 'Wrist Flexion', 1: 'Wrist Extension', 2: 'Pronation', 3: 'Supination', 4: 'Hand Open', 5: 'Power Grip', 6: 'Pinch Grip', 7: 'Rest'}, + '10 Reps (Train), 10 Reps x 4 Positions', + "A limb position dataset (with 5 static limb positions).", + "https://ieeexplore.ieee.org/document/5985538") + self.url = "https://github.com/libemg/LimbPosition" + self.dataset_folder = dataset_folder + + def prepare_data(self, split = False): + subjects_values = [str(i) for i in range(1,13)] + position_values = ["P1", "P2", "P3", "P4", "P5"] + classes_values = ["1", "2", "3", "4", "5", "8", "9", "12"] + reps_values = ["1","2","3","4","5","6","7","8","9","10"] + + print('\nPlease cite: ' + self.citation+'\n') + if (not self.check_exists(self.dataset_folder)): + self.download(self.url, self.dataset_folder) + + regex_filters = [ + RegexFilter(left_bound="/S", right_bound="/",values=subjects_values, description='subjects'), + RegexFilter(left_bound = "_", right_bound="_R", values = position_values, description='positions'), + RegexFilter(left_bound = "_C", right_bound="_P", values = classes_values, description='classes'), + RegexFilter(left_bound = "_R", right_bound=".txt", values = reps_values, description='reps'), + ] + odh = OfflineDataHandler() + odh.get_data(folder_location=self.dataset_folder + 'FougnerLimbPosition/', regex_filters=regex_filters, delimiter=",") + data = odh + if split: + data = {'All': odh, 'Train': odh.isolate_data("positions", [0], fast=True), 'Test': odh.isolate_data("positions", list(range(1, len(position_values))), fast=True)} + + return data \ No newline at end of file diff --git a/libemg/_datasets/grab_myo.py b/libemg/_datasets/grab_myo.py new file mode 100644 index 00000000..be3d88ed --- /dev/null +++ b/libemg/_datasets/grab_myo.py @@ -0,0 +1,93 @@ +from libemg._datasets.dataset import Dataset +from libemg.data_handler import OfflineDataHandler, RegexFilter + +class GRABMyo(Dataset): + """ + By default this just uses the 16 forearm electrodes. + """ + def __init__(self, dataset_folder='GRABMyo/', baseline=False): + split = '7 Train, 14 Test (2 Seperate Days x 7 Reps)' + if baseline: + split = '5 Train, 2 Test (Basline)' + Dataset.__init__(self, + 2048, + 16, + 'EMGUSB2+ device (OT Bioelletronica, Italy)', + 43, + {0: 'Lateral Prehension', 1: 'Thumb Adduction', 2: 'Thumb and Little Finger Opposition', 3: 'Thumb and Index Finger Opposition', 4: 'Thumb and Index Finger Extension', 5: 'Thumb and Little Finger Extension', 6: 'Index and Middle Finger Extension', + 7: 'Little Finger Extension', 8: 'Index Finger Extension', 9: 'Thumb Finger Extension', 10: 'Wrist Extension', 11: 'Wrist Flexion', 12: 'Forearm Supination', 13: 'Forearm Pronation', 14: 'Hand Open', 15: 'Hand Close', 16: 'Rest'}, + split, + "GrabMyo: A large cross session dataset including 17 gestures elicited across 3 seperate sessions.", + 'https://www.nature.com/articles/s41597-022-01836-y') + self.dataset_folder = dataset_folder + + def check_if_exist(self): + if (not self.check_exists(self.dataset_folder)): + print("Please download the GRABMyo dataset from: https://physionet.org/content/grabmyo/1.0.2/") + return + print('\nPlease cite: ' + self.citation+'\n') + + +class GRABMyoCrossDay(GRABMyo): + def __init__(self, dataset_folder="GRABMyo"): + GRABMyo.__init__(self, dataset_folder=dataset_folder, baseline=False) + + def prepare_data(self, split = False): + self.check_if_exist() + + sessions = ["1", "2", "3"] + subjects = [str(i) for i in range(1,44)] + classes_values = ["1","2","3","4","5","6","7","8","9","10","11","12","13","14","15","16","17"] + reps_values = ["1","2","3","4","5","6","7"] + + regex_filters = [ + RegexFilter(left_bound = "session", right_bound="_", values = sessions, description='sessions'), + RegexFilter(left_bound = "_gesture", right_bound="_", values = classes_values, description='classes'), + RegexFilter(left_bound = "trial", right_bound=".hea", values = reps_values, description='reps'), + RegexFilter(left_bound="participant", right_bound="_",values=subjects, description='subjects') + ] + + odh = OfflineDataHandler() + odh.get_data(folder_location=self.dataset_folder, regex_filters=regex_filters, delimiter=",") + + forearm_data = odh.isolate_channels(list(range(0,16))) + train_data = forearm_data.isolate_data('sessions', [0], fast=True) + test_data = forearm_data.isolate_data('sessions', [1,2], fast=True) + + data = forearm_data + if split: + data = {'All': forearm_data, 'Train': train_data, 'Test': test_data} + + return data + +class GRABMyoBaseline(GRABMyo): + def __init__(self, dataset_folder="GRABMyo"): + GRABMyo.__init__(self, dataset_folder=dataset_folder, baseline=True) + + def prepare_data(self, split = False): + self.check_if_exist() + + sessions = ["1"] + subjects = [str(i) for i in range(1,44)] + classes_values = ["1","2","3","4","5","6","7","8","9","10","11","12","13","14","15","16","17"] + reps_values = ["1","2","3","4","5","6","7"] + + regex_filters = [ + RegexFilter(left_bound = "session", right_bound="_", values = sessions, description='session'), + RegexFilter(left_bound = "_gesture", right_bound="_", values = classes_values, description='classes'), + RegexFilter(left_bound = "trial", right_bound=".hea", values = reps_values, description='reps'), + RegexFilter(left_bound="participant", right_bound="_",values=subjects, description='subjects') + ] + + odh = OfflineDataHandler() + odh.get_data(folder_location=self.dataset_folder, regex_filters=regex_filters, delimiter=",") + + forearm_data = odh.isolate_channels(list(range(0,16))) + train_data = forearm_data.isolate_data('reps', [0,1,2,3,4], fast=True) + test_data = forearm_data.isolate_data('reps', [5,6], fast=True) + + data = forearm_data + if split: + data = {'All': forearm_data, 'Train': train_data, 'Test': test_data} + + return data \ No newline at end of file diff --git a/libemg/_datasets/hyser.py b/libemg/_datasets/hyser.py new file mode 100644 index 00000000..ead3542f --- /dev/null +++ b/libemg/_datasets/hyser.py @@ -0,0 +1,328 @@ +from abc import ABC, abstractmethod +from copy import deepcopy +from pathlib import Path +from typing import Sequence + +import numpy as np + +from libemg.data_handler import RegexFilter, FilePackager, OfflineDataHandler, MetadataFetcher +from libemg._datasets.dataset import Dataset + + +class _Hyser(Dataset, ABC): + def __init__(self, gestures, num_reps, description, dataset_folder, analysis = 'baseline', subjects = None): + super().__init__( + sampling=2048, + num_channels=256, + recording_device='OT Bioelettronica Quattrocento', + num_subjects=20, + gestures=gestures, + num_reps=num_reps, + description=description, + citation='https://doi.org/10.13026/ym7v-bh53' + ) + if subjects is None: + subjects = [str(idx + 1).zfill(2) for idx in range(self.num_subjects)] # +1 due to Python indexing + + self.url = 'https://www.physionet.org/content/hd-semg/1.0.0/' + self.dataset_folder = dataset_folder + self.analysis = analysis + self.subjects = subjects + + @property + def common_regex_filters(self): + sessions_values = ['1', '2'] if self.analysis == 'sessions' else ['1'] # only grab first session unless both are desired + filters = [ + RegexFilter(left_bound='subject', right_bound='_session', values=self.subjects, description='subjects', return_value=True), + RegexFilter(left_bound='_session', right_bound='/', values=sessions_values, description='sessions') + ] + return filters + + def prepare_data(self, split = False): + if (not self.check_exists(self.dataset_folder)): + raise FileNotFoundError(f"Didn't find Hyser data in {self.dataset_folder} directory. Please download the dataset and \ + store it in the appropriate directory before running prepare_data(). See {self.url} for download details.") + return self._prepare_data_helper(split=split) + + @abstractmethod + def _prepare_data_helper(self, split = False) -> dict | OfflineDataHandler: + ... + + +class Hyser1DOF(_Hyser): + def __init__(self, dataset_folder: str = 'Hyser1DOF', analysis: str = 'baseline', subjects: Sequence[str] | None = None): + """1 degree of freedom (DOF) Hyser dataset. + + Parameters + ---------- + dataset_folder: str, default='Hyser1DOF' + Directory that contains Hyser 1 DOF dataset. + analysis: str, default='baseline' + Determines which type of data will be extracted and considered train/test splits. If 'baseline', only grabs data from the first session and splits based on + reps. If 'sessions', grabs data from both sessions and return the first session as train and the second session as test. + subjects: Sequence[str] or None, default=None + Subjects to parse (e.g., ['01', '03', '10']). If None, parses all participants. Defaults to None. + """ + gestures = {1: 'Thumb', 2: 'Index', 3: 'Middle', 4: 'Ring', 5: 'Little'} + description = 'Hyser 1 DOF dataset. Includes within-DOF finger movements. Ground truth finger forces are recorded for use in finger force regression.' + super().__init__(gestures=gestures, num_reps=3, description=description, dataset_folder=dataset_folder, analysis=analysis, subjects=subjects) + + def _prepare_data_helper(self, split = False): + filename_filters = deepcopy(self.common_regex_filters) + filename_filters.append(RegexFilter(left_bound='_sample', right_bound='.hea', values=[str(idx + 1) for idx in range(self.num_reps)], description='reps')) + filename_filters.append(RegexFilter(left_bound='_finger', right_bound='_sample', values=['1', '2', '3', '4', '5'], description='finger')) + + regex_filters = deepcopy(filename_filters) + regex_filters.append(RegexFilter(left_bound='1dof_', right_bound='_finger', values=['raw'], description='data_type')) + + metadata_fetchers = [ + FilePackager(RegexFilter(left_bound='/1dof_', right_bound='_finger', values=['force'], description='labels'), + package_function=filename_filters, load='p_signal') + ] + odh = OfflineDataHandler() + odh.get_data(folder_location=self.dataset_folder, regex_filters=regex_filters, metadata_fetchers=metadata_fetchers) + data = odh + if split: + if self.analysis == 'sessions': + data = {'All': odh, 'Train': odh.isolate_data('sessions', [0], fast=True), 'Test': odh.isolate_data('sessions', [1], fast=True)} + elif self.analysis == 'baseline': + data = {'All': odh, 'Train': odh.isolate_data('reps', [0, 1], fast=True), 'Test': odh.isolate_data('reps', [2], fast=True)} + else: + raise ValueError(f"Unexpected value for analysis. Supported values are sessions, baseline. Got: {self.analysis}.") + return data + + +class HyserNDOF(_Hyser): + def __init__(self, dataset_folder: str = 'HyserNDOF', analysis: str = 'baseline', subjects: Sequence[str] | None = None): + """N degree of freedom (DOF) Hyser dataset. + + Parameters + ---------- + dataset_folder: str, default='HyserNDOF' + Directory that contains Hyser N DOF dataset. + analysis: str, default='baseline' + Determines which type of data will be extracted and considered train/test splits. If 'baseline', only grabs data from the first session and splits based on + reps. If 'sessions', grabs data from both sessions and return the first session as train and the second session as test. + subjects: Sequence[str] or None, default=None + Subjects to parse (e.g., ['01', '03', '10']). If None, parses all participants. Defaults to None. + """ + # TODO: Add a 'regression' flag... maybe add a 'DOFs' parameter instead of just gestures? + gestures = {1: 'Thumb', 2: 'Index', 3: 'Middle', 4: 'Ring', 5: 'Little'} + description = 'Hyser N DOF dataset. Includes combined finger movements. Ground truth finger forces are recorded for use in finger force regression.' + super().__init__(gestures=gestures, num_reps=2, description=description, dataset_folder=dataset_folder, analysis=analysis, subjects=subjects) + self.finger_combinations = { + 1: 'Thumb + Index', + 2: 'Thumb + Middle', + 3: 'Thumg + Ring', + 4: 'Thumb + Little', + 5: 'Index + Middle', + 6: 'Thumb + Index + Middle', + 7: 'Index + Middle + Ring', + 8: 'Middle + Ring + Little', + 9: 'Index + Middle + Ring + Little', + 10: 'All Fingers', + 11: 'Thumb + Index (Opposing)', + 12: 'Thumb + Middle (Opposing)', + 13: 'Thumg + Ring (Opposing)', + 14: 'Thumb + Little (Opposing)', + 15: 'Index + Middle (Opposing)' + } + + def _prepare_data_helper(self, split = False) -> dict | OfflineDataHandler: + filename_filters = deepcopy(self.common_regex_filters) + filename_filters.append(RegexFilter(left_bound='_sample', right_bound='.hea', values=[str(idx + 1) for idx in range(self.num_reps)], description='reps')) + filename_filters.append(RegexFilter(left_bound='_combination', right_bound='_sample', values=[str(idx + 1) for idx in range(len(self.finger_combinations))], description='finger_combinations')) + + regex_filters = deepcopy(filename_filters) + regex_filters.append(RegexFilter(left_bound='/ndof_', right_bound='_combination', values=['raw'], description='data_type')) + + metadata_fetchers = [ + FilePackager(RegexFilter(left_bound='/ndof_', right_bound='_combination', values=['force'], description='labels'), + package_function=filename_filters, load='p_signal') + ] + odh = OfflineDataHandler() + odh.get_data(folder_location=self.dataset_folder, regex_filters=regex_filters, metadata_fetchers=metadata_fetchers) + data = odh + if split: + if self.analysis == 'sessions': + data = {'All': odh, 'Train': odh.isolate_data('sessions', [0], fast=True), 'Test': odh.isolate_data('sessions', [1], fast=True)} + elif self.analysis == 'baseline': + data = {'All': odh, 'Train': odh.isolate_data('reps', [0], fast=True), 'Test': odh.isolate_data('reps', [1], fast=True)} + else: + raise ValueError(f"Unexpected value for analysis. Supported values are sessions, baseline. Got: {self.analysis}.") + + return data + + +class HyserRandom(_Hyser): + def __init__(self, dataset_folder: str = 'HyserRandom', analysis: str = 'baseline', subjects: Sequence[str] | None = None): + """Random task (DOF) Hyser dataset. + + Parameters + ---------- + dataset_folder: str, default='HyserRandom' + Directory that contains Hyser random task dataset. + analysis: str, default='baseline' + Determines which type of data will be extracted and considered train/test splits. If 'baseline', only grabs data from the first session and splits based on + reps. If 'sessions', grabs data from both sessions and return the first session as train and the second session as test. + subjects: Sequence[str] or None, default=None + Subjects to parse (e.g., ['01', '03', '10']). If None, parses all participants. Defaults to None. + """ + gestures = {1: 'Thumb', 2: 'Index', 3: 'Middle', 4: 'Ring', 5: 'Little'} + description = 'Hyser random dataset. Includes random motions performed by users. Ground truth finger forces are recorded for use in finger force regression.' + super().__init__(gestures=gestures, num_reps=5, description=description, dataset_folder=dataset_folder, analysis=analysis, subjects=subjects) + self.subjects = [s for s in self.subjects if s != '10'] + + + def _prepare_data_helper(self, split = False) -> dict | OfflineDataHandler: + filename_filters = deepcopy(self.common_regex_filters) + filename_filters.append(RegexFilter(left_bound='_sample', right_bound='.hea', values=[str(idx + 1) for idx in range(self.num_reps)], description='reps')) + + regex_filters = deepcopy(filename_filters) + regex_filters.append(RegexFilter(left_bound='/random_', right_bound='_sample', values=['raw'], description='data_type')) + + metadata_fetchers = [ + FilePackager(RegexFilter(left_bound='/random_', right_bound='_sample', values=['force'], description='labels'), + package_function=filename_filters, load='p_signal') + ] + odh = OfflineDataHandler() + odh.get_data(folder_location=self.dataset_folder, regex_filters=regex_filters, metadata_fetchers=metadata_fetchers) + data = odh + if split: + if self.analysis == 'sessions': + data = {'All': odh, 'Train': odh.isolate_data('sessions', [0], fast=True), 'Test': odh.isolate_data('sessions', [1], fast=True)} + elif self.analysis == 'baseline': + data = {'All': odh, 'Train': odh.isolate_data('reps', [0, 1, 2], fast=True), 'Test': odh.isolate_data('reps', [3, 4], fast=True)} + else: + raise ValueError(f"Unexpected value for analysis. Supported values are sessions, baseline. Got: {self.analysis}.") + + return data + + +class _PRLabelsFetcher(MetadataFetcher): + def __init__(self): + super().__init__(description='classes') + self.sample_regex = RegexFilter(left_bound='_sample', right_bound='.hea', values=[str(idx + 1) for idx in range(204)], description='samples') + + def _get_labels(self, filename): + label_filename_map = { + 'dynamic': 'label_dynamic.txt', + 'maintenance': 'label_maintenance.txt' + } + matches = [] + for task_type, labels_file in label_filename_map.items(): + if task_type in filename: + matches.append(labels_file) + + assert len(matches) == 1, f"Expected a single label file for this file, but got {len(matches)}. Got filename: {filename}. Filename should contain either 'dynamic' or 'maintenance'." + + labels_file = matches[0] + parent = Path(filename).absolute().parent + labels_file = Path(parent, labels_file).as_posix() + return np.loadtxt(labels_file, delimiter=',', dtype=int) + + def __call__(self, filename, file_data, all_files): + labels = self._get_labels(filename) + sample_idx = self.sample_regex.get_metadata(filename) + assert isinstance(sample_idx, int), f"Expected index, but got value of type {type(sample_idx)}." + return labels[sample_idx] - 1 # -1 to produce 0-indexed labels + + +class _PRRepFetcher(_PRLabelsFetcher): + def __init__(self): + super().__init__() + self.description = 'reps' + + def __call__(self, filename, file_data, all_files): + label = super().__call__(filename, file_data, all_files) + 1 # +1 b/c this returns 0-indexed labels, but the files are 1-indexed + labels = self._get_labels(filename) + same_label_mask = np.where(labels == label)[0] + sample_idx = self.sample_regex.get_metadata(filename) + rep_idx = list(same_label_mask).index(sample_idx) + if 'dynamic' in filename: + # Each trial is 3 dynamic reps, 1 maintenance rep + rep_idx = rep_idx // 3 + + assert rep_idx <= 1, f"Rep values should be 0 or 1 (2 total reps). Got: {rep_idx}." + return np.array(rep_idx) + + +class HyserPR(_Hyser): + def __init__(self, dataset_folder: str = 'HyserPR', analysis: str = 'baseline', subjects: Sequence[str] | None = None): + """Pattern recognition (PR) Hyser dataset. + + Parameters + ---------- + dataset_folder: str, default='HyserPR' + Directory that contains Hyser PR dataset. + analysis: str, default='baseline' + Determines which type of data will be extracted and considered train/test splits. If 'baseline', only grabs data from the first session and splits based on + reps. If 'sessions', grabs data from both sessions and return the first session as train and the second session as test. + subjects: Sequence[str] or None, default=None + Subjects to parse (e.g., ['01', '03', '10']). If None, parses all participants. Defaults to None. + """ + gestures = { + 1: 'Thumb Extension', + 2: 'Index Finger Extension', + 3: 'Middle Finger Extension', + 4: 'Ring Finger Extension', + 5: 'Little Finger Extension', + 6: 'Wrist Flexion', + 7: 'Wrist Extension', + 8: 'Wrist Radial', + 9: 'Wrist Ulnar', + 10: 'Wrist Pronation', + 11: 'Wrist Supination', + 12: 'Extension of Thumb and Index Fingers', + 13: 'Extension of Index and Middle Fingers', + 14: 'Wrist Flexion Combined with Hand Close', + 15: 'Wrist Extension Combined with Hand Close', + 16: 'Wrist Radial Combined with Hand Close', + 17: 'Wrist Ulnar Combined with Hand Close', + 18: 'Wrist Pronation Combined with Hand Close', + 19: 'Wrist Supination Combined with Hand Close', + 20: 'Wrist Flexion Combined with Hand Open', + 21: 'Wrist Extension Combined with Hand Open', + 22: 'Wrist Radial Combined with Hand Open', + 23: 'Wrist Ulnar Combined with Hand Open', + 24: 'Wrist Pronation Combined with Hand Open', + 25: 'Wrist Supination Combined with Hand Open', + 26: 'Extension of Thumb, Index and Middle Fingers', + 27: 'Extension of Index, Middle and Ring Fingers', + 28: 'Extension of Middle, Ring and Little Fingers', + 29: 'Extension of Index, Middle, Ring and Little Fingers', + 30: 'Hand Close', + 31: 'Hand Open', + 32: 'Thumb and Index Fingers Pinch', + 33: 'Thumb, Index and Middle Fingers Pinch', + 34: 'Thumb and Middle Fingers Pinch' + } + description = 'Hyser pattern recognition (PR) dataset. Includes dynamic and maintenance tasks for 34 hand gestures.' + super().__init__(gestures=gestures, num_reps=2, description=description, dataset_folder=dataset_folder, analysis=analysis, subjects=subjects) # num_reps=2 b/c 2 trials + self.subjects = [s for s in self.subjects if s not in ('03', '11')] # subjects 3 and 11 are missing classes + + def _prepare_data_helper(self, split = False) -> dict | OfflineDataHandler: + filename_filters = deepcopy(self.common_regex_filters) + filename_filters.append(RegexFilter(left_bound='_sample', right_bound='.hea', values=[str(idx + 1) for idx in range(204)], description='samples')) # max # of dynamic tasks + filename_filters.append(RegexFilter(left_bound='/', right_bound='_', values=['dynamic', 'maintenance'], description='tasks')) + + regex_filters = deepcopy(filename_filters) + regex_filters.append(RegexFilter(left_bound='_', right_bound='_sample', values=['raw'], description='data_type')) + + metadata_fetchers = [ + _PRLabelsFetcher(), + _PRRepFetcher() + ] + odh = OfflineDataHandler() + odh.get_data(folder_location=self.dataset_folder, regex_filters=regex_filters, metadata_fetchers=metadata_fetchers) + + data = odh + if split: + if self.analysis == 'sessions': + data = {'All': odh, 'Train': odh.isolate_data('sessions', [0], fast=True), 'Test': odh.isolate_data('sessions', [1], fast=True)} + elif self.analysis == 'baseline': + data = {'All': odh, 'Train': odh.isolate_data('reps', [0], fast=True), 'Test': odh.isolate_data('reps', [1], fast=True)} + else: + raise ValueError(f"Unexpected value for analysis. Supported values are sessions, baseline. Got: {self.analysis}.") + + return data diff --git a/libemg/_datasets/intensity.py b/libemg/_datasets/intensity.py new file mode 100644 index 00000000..dbecd855 --- /dev/null +++ b/libemg/_datasets/intensity.py @@ -0,0 +1,40 @@ +from libemg._datasets.dataset import Dataset +from libemg.data_handler import OfflineDataHandler, RegexFilter + +class ContractionIntensity(Dataset): + def __init__(self, dataset_folder="ContractionIntensity/"): + Dataset.__init__(self, + 1000, + 8, + 'BE328 by Liberating Technologies, Inc', + 10, + {0: "No Motion", 1: "Wrist Flexion", 2: "Wrist Extension", 3: "Wrist Pronation", 4: "Wrist Supination", 5: "Chuck Grip", 6: "Hand Open"}, + '4 Ramp Reps (Train), 4 Reps x 20%, 30%, 40%, 50%, 60%, 70%, 80%, MVC (Test)', + "A contraction intensity dataset.", + "https://pubmed.ncbi.nlm.nih.gov/23894224/") + self.url = "https://github.com/libemg/ContractionIntensity" + self.dataset_folder = dataset_folder + + def prepare_data(self, split = False): + subjects_values = [str(i) for i in range(1,11)] + intensity_values = ["Ramp", "20P", "30P", "40P", "50P", "60P", "70P", "80P", "MVC"] + classes_values = [str(i) for i in range(1,8)] + reps_values = ["1","2","3","4"] + + print('\nPlease cite: ' + self.citation+'\n') + if (not self.check_exists(self.dataset_folder)): + self.download(self.url, self.dataset_folder) + + regex_filters = [ + RegexFilter(left_bound="/S", right_bound="/",values=subjects_values, description='subjects'), + RegexFilter(left_bound = "_", right_bound="_C", values = intensity_values, description='intensities'), + RegexFilter(left_bound = "_C", right_bound="_R", values = classes_values, description='classes'), + RegexFilter(left_bound = "_R", right_bound=".csv", values = reps_values, description='reps'), + ] + odh = OfflineDataHandler() + odh.get_data(folder_location=self.dataset_folder, regex_filters=regex_filters, delimiter=",") + data = odh + if split: + data = {'All': odh, 'Train': odh.isolate_data("intensities", [0], fast=True), 'Test': odh.isolate_data("intensities", list(range(1, len(intensity_values))), fast=True)} + + return data \ No newline at end of file diff --git a/libemg/_datasets/kaufmann_md.py b/libemg/_datasets/kaufmann_md.py new file mode 100644 index 00000000..846e20a5 --- /dev/null +++ b/libemg/_datasets/kaufmann_md.py @@ -0,0 +1,40 @@ +from libemg._datasets.dataset import Dataset +from libemg.data_handler import OfflineDataHandler, RegexFilter + +class KaufmannMD(Dataset): + def __init__(self, dataset_folder="MultiDay/"): + Dataset.__init__(self, + 2048, + 4, + 'MindMedia', + 1, + {0: "No Motion", 1:"Wrist Extension", 2:"Wrist Flexion", 3:"Wrist Adduction", + 4:"Wrist Abduction", 5:"Wrist Supination", 6:"Wrist Pronation", 7:"Hand Open", + 8:"Hand Closed", 9:"Key Grip", 10:"Index Point"}, + '1 rep per day, 120 days total. 60/60 train-test split', + "A single subject, multi-day (120) collection.", + "https://ieeexplore.ieee.org/document/5627288") + self.url = "https://github.com/LibEMG/MultiDay" + self.dataset_folder = dataset_folder + + def prepare_data(self, split = False): + subjects_values = ["0"] + day_values = [str(i) for i in range(1,122)] + classes_values = [str(i) for i in range(11)] + + print('\nPlease cite: ' + self.citation+'\n') + if (not self.check_exists(self.dataset_folder)): + self.download(self.url, self.dataset_folder) + + regex_filters = [ + RegexFilter(left_bound="/S", right_bound="_D",values=subjects_values, description='subjects'), + RegexFilter(left_bound = "_D", right_bound="_C", values = day_values, description='days'), + RegexFilter(left_bound = "_C", right_bound=".csv", values = classes_values, description='classes'), + ] + odh = OfflineDataHandler() + odh.get_data(folder_location=self.dataset_folder, regex_filters=regex_filters, delimiter=" ") + data = odh + if split: + data = {'All': odh, 'Train': odh.isolate_data("days", list(range(60)), fast=True), 'Test': odh.isolate_data("days", list(range(60,121)), fast=True)} + + return data \ No newline at end of file diff --git a/libemg/_datasets/myodisco.py b/libemg/_datasets/myodisco.py new file mode 100644 index 00000000..8674e704 --- /dev/null +++ b/libemg/_datasets/myodisco.py @@ -0,0 +1,73 @@ +from libemg._datasets.dataset import Dataset +from libemg.data_handler import OfflineDataHandler, RegexFilter +from libemg.utils import * +from libemg.feature_extractor import FeatureExtractor + +class MyoDisCo(Dataset): + def __init__(self, dataset_folder="MyoDisCo/", cross_day=False): + self.cross_day = cross_day + desc = 'The MyoDisCo dataset which includes both the across day and limb position confounds. (Limb Position Version)' + if self.cross_day: + desc = 'The MyoDisCo dataset which includes both the across day and limb position confounds. (Cross Day Version)' + Dataset.__init__(self, + 200, + 8, + 'Myo Armband', + 14, + {0: "Wrist Extension", 1: "Finger Gun", 2: "Wrist Flexion", 3: "Hand Close", 4: "Hand Open", 5: "Thumbs Up", 6: "Rest"}, + '20 (Train) and 20 (Test) - Each gesture ~0.5s', + desc, + "https://iopscience.iop.org/article/10.1088/1741-2552/ad4915/meta") + self.url = "https://github.com/libemg/MyoDisCo" + self.dataset_folder = dataset_folder + + def prepare_data(self, split = False): + print('\nPlease cite: ' + self.citation+'\n') + if (not self.check_exists(self.dataset_folder)): + self.download(self.url, self.dataset_folder) + + + sets_values = ['day1', 'day2', 'positions'] + subjects_value = [str(i) for i in range(1,15)] + classes_values = ["1","2","3","4","5","8","9"] + reps_values = [str(i) for i in range(0,20)] + regex_filters = [ + RegexFilter(left_bound = "/", right_bound="/", values = sets_values, description='sets'), + RegexFilter(left_bound = "C_", right_bound="_EMG", values = classes_values, description='classes'), + RegexFilter(left_bound = "R_", right_bound="_C", values = reps_values, description='reps'), + RegexFilter(left_bound="S", right_bound="/",values=subjects_value, description='subjects') + ] + odh = OfflineDataHandler() + odh.get_data(folder_location=self.dataset_folder, regex_filters=regex_filters, delimiter=",") + + fe = FeatureExtractor() + # We need to parse each item to remove no motion + for i, d in enumerate(odh.data): + w = get_windows(d, 20, 5) + mav = fe.extract_features(['MAV'], w, array=True) + if odh.classes[i][0][0] == 6: + odh.data[i] = d[100:200] + else: + mval = np.argmax(np.mean(mav, axis=1)) * 5 + max_idx = min([len(d), mval + 50]) + min_idx = max([0, mval - 50]) + odh.data[i] = d[min_idx:max_idx] + + odh.sets[i] = np.ones((len(odh.data[i]), 1)) * odh.sets[i][0][0] + odh.classes[i] = np.ones((len(odh.data[i]), 1)) * odh.classes[i][0][0] + odh.reps[i] = np.ones((len(odh.data[i]), 1)) * odh.reps[i][0][0] + odh.subjects[i] = np.ones((len(odh.data[i]), 1)) * odh.subjects[i][0][0] + + + if self.cross_day: + odh_train = odh.isolate_data('sets', [0], fast=True) + odh_test = odh.isolate_data('sets', [1], fast=True) + else: + odh_train = odh.isolate_data('sets', [1], fast=True) + odh_test = odh.isolate_data('sets', [2], fast=True) + + data = odh + if split: + data = {'All': odh, 'Train': odh_train, 'Test': odh_test} + + return data diff --git a/libemg/_datasets/nina_pro.py b/libemg/_datasets/nina_pro.py new file mode 100644 index 00000000..d7299376 --- /dev/null +++ b/libemg/_datasets/nina_pro.py @@ -0,0 +1,223 @@ +from pathlib import Path + +from libemg._datasets.dataset import Dataset +from libemg.data_handler import OfflineDataHandler, RegexFilter, ColumnFetcher +import os +import scipy.io as sio +import zipfile +import numpy as np + +def find_all_files_of_type_recursively(dir, terminator): + files = os.listdir(dir) + file_list = [] + for file in files: + if file.endswith(terminator): + file_list.append(dir+file) + else: + if os.path.isdir(dir+file): + file_list += find_all_files_of_type_recursively(dir+file+'/',terminator) + return file_list + +class Ninapro(Dataset): + def __init__(self, + sampling, num_channels, recording_device, num_subjects, gestures, num_reps, description, citation, + dataset_folder="Ninapro"): + # downloading the Ninapro dataset is not supported (no permission given from the authors)' + # however, you can download it from http://ninapro.hevs.ch/DB8 + # the subject zip files should be placed at: /NinaproDB8/DB8_s#.zip + Dataset.__init__(self, sampling, num_channels, recording_device, num_subjects, gestures, num_reps, description, citation) + self.dataset_folder = dataset_folder + self.exercise_step = [] + + def convert_to_compatible(self): + # get the zip files (original format they're downloaded in) + zip_files = find_all_files_of_type_recursively(self.dataset_folder,".zip") + # unzip the files -- if any are there (successive runs skip this) + for zip_file in zip_files: + with zipfile.ZipFile(zip_file, 'r') as zip_ref: + zip_ref.extractall(zip_file[:-4]+'/') + os.remove(zip_file) + # get the mat files (the files we want to convert to csv) + mat_files = find_all_files_of_type_recursively(self.dataset_folder,".mat") + for mat_file in mat_files: + self.convert_to_csv(mat_file) + + def convert_to_csv(self, mat_file): + # read the mat file + mat_file = mat_file.replace("\\", "/") + mat_dir = mat_file.split('/') + mat_dir = os.path.join(*mat_dir[:-1],"") + mat = sio.loadmat(mat_file) + # get the data + exercise = int(mat_file.split('_')[-1][1]) + exercise_offset = self.exercise_step[exercise-1] # 0 reps already included + data = mat['emg'] + restimulus = mat['restimulus'] + rerepetition = mat['rerepetition'] + try: + cyberglove_data = mat['glove'] + cyberglove_directory = 'cyberglove' + except KeyError: + # No cyberglove data + cyberglove_data = None + cyberglove_directory = '' + if data.shape[0] != restimulus.shape[0]: # this happens in some cases + min_shape = min([data.shape[0], restimulus.shape[0]]) + data = data[:min_shape,:] + restimulus = restimulus[:min_shape,] + rerepetition = rerepetition[:min_shape,] + if cyberglove_data is not None: + cyberglove_data = cyberglove_data[:min_shape,] + # remove 0 repetition - collection buffer + remove_mask = (rerepetition != 0).squeeze() + data = data[remove_mask,:] + restimulus = restimulus[remove_mask] + rerepetition = rerepetition[remove_mask] + if cyberglove_data is not None: + cyberglove_data = cyberglove_data[remove_mask, :] + # important little not here: + # the "rest" really is only the rest between motions, not a dedicated rest class. + # there will be many more rest repetitions (as it is between every class) + # so usually we really care about classifying rest as its important (most of the time we do nothing) + # but for this dataset it doesn't make sense to include (and not its just an offline showcase of the library) + # I encourage you to plot the restimulus to see what I mean. -> plt.plot(restimulus) + # so we remove the rest class too + remove_mask = (restimulus != 0).squeeze() + data = data[remove_mask,:] + restimulus = restimulus[remove_mask] + rerepetition = rerepetition[remove_mask] + if cyberglove_data is not None: + cyberglove_data = cyberglove_data[remove_mask, :] + tail = 0 + while tail < data.shape[0]-1: + rep = rerepetition[tail][0] # remove the 1 offset (0 was the collection buffer) + motion = restimulus[tail][0] # remove the 1 offset (0 was between motions "rest") + # find head + head = np.where(rerepetition[tail:] != rep)[0] + if head.shape == (0,): # last segment of data + head = data.shape[0] -1 + else: + head = head[0] + tail + if cyberglove_data is not None: + # Combine cyberglove and EMG data + data_for_file = np.concatenate((data[tail:head, :], cyberglove_data[tail:head, :]), axis=1) + else: + data_for_file = data[tail:head,:] + + # downsample to 1kHz from 2kHz using decimation + data_for_file = data_for_file[::2, :] + # write to csv + csv_file = Path(mat_dir, cyberglove_directory, f"C{motion - 1}R{rep - 1 + exercise_offset}.csv") + csv_file.parent.mkdir(parents=True, exist_ok=True) + np.savetxt(csv_file, data_for_file, delimiter=',') + tail = head + os.remove(mat_file) + + +class NinaproDB2(Ninapro): + def __init__(self, dataset_folder="NinaProDB2/", use_cyberglove: bool = False): + Ninapro.__init__(self, + 2000, + 12, + 'Delsys', + 40, + {0: 'See Exercises B and C from: https://ninapro.hevs.ch/instructions/DB2.html'}, + '4 Train, 2 Test', + "NinaProb DB2.", + 'https://ninapro.hevs.ch/', + dataset_folder = dataset_folder) + self.exercise_step = [0,0,0] + self.num_cyberglove_dofs = 22 + self.use_cyberglove = use_cyberglove # needed b/c some files have EMG but no cyberglove + + def prepare_data(self, split = False, subjects_values = None, reps_values = None, classes_values = None): + if subjects_values is None: + subjects_values = [str(i) for i in range(1,41)] + if reps_values is None: + reps_values = [str(i) for i in range(6)] + if classes_values is None: + classes_values = [str(i) for i in range(50)] + + print('\nPlease cite: ' + self.citation+'\n') + if (not self.check_exists(self.dataset_folder)): + raise FileNotFoundError("Please download the NinaProDB2 dataset from: https://ninapro.hevs.ch/instructions/DB2.html") + self.convert_to_compatible() + regex_filters = [ + RegexFilter(left_bound = "/C", right_bound="R", values = classes_values, description='classes'), + RegexFilter(left_bound="R", right_bound=".csv", values=reps_values, description='reps'), + RegexFilter(left_bound="DB2_s", right_bound="/",values=subjects_values, description='subjects') + ] + + if self.use_cyberglove: + # Only want cyberglove files + regex_filters.append(RegexFilter(left_bound="/", right_bound="/C", values=['cyberglove'], description='')) + metadata_fetchers = [ + ColumnFetcher('cyberglove', column_mask=[idx for idx in range(self.num_channels, self.num_channels + self.num_cyberglove_dofs)]) + ] + else: + metadata_fetchers = None + + emg_column_mask = [idx for idx in range(self.num_channels)] # first columns should be EMG + odh = OfflineDataHandler() + odh.get_data(folder_location=self.dataset_folder, regex_filters=regex_filters, metadata_fetchers=metadata_fetchers, delimiter=",", data_column=emg_column_mask) + data = odh + if split: + data = {'All': odh, 'Train': odh.isolate_data('reps', [0,1,2,3], fast=True), 'Test': odh.isolate_data('reps', [4,5], fast=True)} + + return data + +class NinaproDB8(Ninapro): + def __init__(self, dataset_folder="NinaProDB8/"): + # NOTE: This expects each subject's data to be in its own zip file, so the data files for one subject end up in a single directory once we unzip them (e.g., DB8_s1) + gestures = { + 0: "rest", + 1: "thumb flexion/extension", + 2: "thumb abduction/adduction", + 3: "index finger flexion/extension", + 4: "middle finger flexion/extension", + 5: "combined ring and little fingers flexion/extension", + 6: "index pointer", + 7: "cylindrical grip", + 8: "lateral grip", + 9: "tripod grip" + } + + super().__init__( + sampling=1111, + num_channels=16, + recording_device='Delsys Trigno', + num_subjects=12, + gestures=gestures, + num_reps=22, + description='Ninapro DB8 - designed for regression of finger kinematics. Ground truth labels are provided via cyberglove data.', + citation='https://ninapro.hevs.ch/', + dataset_folder=dataset_folder + ) + self.exercise_step = [0,10,20] + self.num_cyberglove_dofs = 18 + + def prepare_data(self, split = False, subjects_values = None, reps_values = None, classes_values = None): + if subjects_values is None: + subjects_values = [str(i) for i in range(1,self.num_subjects + 1)] + if reps_values is None: + reps_values = [str(i) for i in range(self.num_reps)] + if classes_values is None: + classes_values = [str(i) for i in range(9)] + + self.convert_to_compatible() + + regex_filters = [ + RegexFilter(left_bound = "/C", right_bound="R", values = classes_values, description='classes'), + RegexFilter(left_bound = "R", right_bound=".csv", values = reps_values, description='reps'), + RegexFilter(left_bound="DB8_s", right_bound="/",values=subjects_values, description='subjects') + ] + metadata_fetchers = [ + ColumnFetcher('labels', column_mask=[idx for idx in range(self.num_channels, self.num_channels + self.num_cyberglove_dofs)]) + ] + emg_column_mask = [idx for idx in range(self.num_channels)] # first columns should be EMG + odh = OfflineDataHandler() + odh.get_data(folder_location=self.dataset_folder, regex_filters=regex_filters, metadata_fetchers=metadata_fetchers, delimiter=",", data_column=emg_column_mask) + data = odh + if split: + data = {'All': odh, 'Train': odh.isolate_data('reps', [0, 1, 2, 3], fast=True), 'Test': odh.isolate_data('reps', [4, 5], fast=True)} + return data diff --git a/libemg/_datasets/one_site_biopoint.py b/libemg/_datasets/one_site_biopoint.py new file mode 100644 index 00000000..c92fdf9e --- /dev/null +++ b/libemg/_datasets/one_site_biopoint.py @@ -0,0 +1,48 @@ + +from libemg._datasets.dataset import Dataset +from libemg.data_handler import OfflineDataHandler, RegexFilter +from libemg.feature_extractor import FeatureExtractor +from libemg.utils import * + +class OneSiteBiopoint(Dataset): + def __init__(self, dataset_folder='CIIL_WeaklySupervised/'): + Dataset.__init__(self, + 2000, + 1, + 'SiFi-Labs BioPoint', + 8, + {0: 'Close', 1: 'Open', 2: 'Rest', 3: 'Flexion', 4: 'Extension'}, + 'Six reps', + "A single site, multimodal sensor for gesture recognition", + 'EMBC 2024 - Not Yet Published') + self.url = "https://unbcloud-my.sharepoint.com/:u:/g/personal/ecampbe2_unb_ca/EZG9zfWg_hdJl4De1Clnl34ByTjYqStTB90Nj6EaHkGSnA?e=JQLU7z" + self.dataset_folder = dataset_folder + + def prepare_data(self, split = False): + print('\nPlease cite: ' + self.citation+'\n') + if (not self.check_exists(self.dataset_folder)): + self.download_via_onedrive(self.url, self.dataset_folder) + + subjects = [str(i) for i in range(0, 8)] + classes_values = [str(i) for i in range(0,17)] + reps_values = [str(i) for i in range(0,6)] + regex_filters = [ + RegexFilter(left_bound = "/S", right_bound="/", values = subjects, description='subjects'), + RegexFilter(left_bound = "R_", right_bound="EMG-bio.csv", values = reps_values, description='reps'), + RegexFilter(left_bound = "C_", right_bound="_R", values = classes_values, description='classes') + ] + odh_s = OfflineDataHandler() + odh_s.get_data(folder_location=self.dataset_folder+"OneSiteBioPoint/", + regex_filters=regex_filters, + delimiter=",") + + + if split: + data = {'All': data, + 'Train': odh_s.isolate_data("reps", list(range(0,3)), fast=True), + 'Test': odh_s.isolate_data("reps", list(range(3,6)), fast=True)} + + return data + + + diff --git a/libemg/_datasets/one_subject_emager.py b/libemg/_datasets/one_subject_emager.py new file mode 100644 index 00000000..2f580437 --- /dev/null +++ b/libemg/_datasets/one_subject_emager.py @@ -0,0 +1,41 @@ +from pathlib import Path + +import numpy as np +from libemg._datasets.dataset import Dataset +from libemg.data_handler import OfflineDataHandler, RegexFilter, FilePackager + + +class OneSubjectEMaGerDataset(Dataset): + def __init__(self, dataset_folder = 'OneSubjectEMaGerDataset/'): + super().__init__( + sampling=1010, + num_channels=64, + recording_device='EMaGer', + num_subjects=1, + gestures={0: 'Hand Close (-) / Hand Open (+)', 1: 'Pronation (-) / Supination (+)'}, + num_reps=5, + description='A simple EMaGer dataset used for regression examples in LibEMG demos.', + citation='N/A' + ) + self.url = 'https://github.com/LibEMG/OneSubjectEMaGerDataset' + self.dataset_folder = dataset_folder + + def prepare_data(self, split = False): + if (not self.check_exists(self.dataset_folder)): + self.download(self.url, self.dataset_folder) + regex_filters = [ + RegexFilter(left_bound='/', right_bound='/', values=['open-close', 'pro-sup'], description='movements'), + RegexFilter(left_bound='_R_', right_bound='_emg.csv', values=[str(idx) for idx in range(self.num_reps)], description='reps') + ] + package_function = lambda x, y: Path(x).parent.absolute() == Path(y).parent.absolute() + metadata_fetchers = [FilePackager(RegexFilter(left_bound='/', right_bound='.txt', values=['labels'], description='labels'), package_function)] + odh = OfflineDataHandler() + odh.get_data(folder_location=self.dataset_folder, regex_filters=regex_filters, metadata_fetchers=metadata_fetchers) + odh.subjects = [] + odh.subjects = [np.zeros((len(d), 1)) for d in odh.data] + odh.extra_attributes.append('subjects') + data = odh + if split: + data = {'All': odh, 'Train': odh.isolate_data('reps', [0, 1, 2, 3], fast=True), 'Test': odh.isolate_data('reps', [4], fast=True)} + + return data diff --git a/libemg/_datasets/one_subject_myo.py b/libemg/_datasets/one_subject_myo.py new file mode 100644 index 00000000..d01737f0 --- /dev/null +++ b/libemg/_datasets/one_subject_myo.py @@ -0,0 +1,40 @@ +from libemg._datasets.dataset import Dataset +from libemg.data_handler import OfflineDataHandler, RegexFilter +import numpy as np + +class OneSubjectMyoDataset(Dataset): + def __init__(self, dataset_folder="OneSubjectMyoDataset/"): + Dataset.__init__(self, + 200, + 8, + 'Myo Armband', + 1, + {0: 'Close', 1: 'Open', 2: 'Rest', 3: 'Flexion', 4: 'Extension'}, + '6 (4 Train, 2 Test)', + "A simple Myo dataset that is used for some of the LibEMG offline demos.", + 'N/A') + self.url = "https://github.com/libemg/OneSubjectMyoDataset" + self.dataset_folder = dataset_folder + + def prepare_data(self, split = False): + if (not self.check_exists(self.dataset_folder)): + self.download(self.url, self.dataset_folder) + + sets_values = ["1","2","3","4","5","6"] + classes_values = ["0","1","2","3","4"] + reps_values = ["0","1"] + regex_filters = [ + RegexFilter(left_bound = "/trial_", right_bound="/", values = sets_values, description='sets'), + RegexFilter(left_bound = "C_", right_bound=".csv", values = classes_values, description='classes'), + RegexFilter(left_bound = "R_", right_bound="_", values = reps_values, description='reps') + ] + odh = OfflineDataHandler() + odh.get_data(folder_location=self.dataset_folder, regex_filters=regex_filters, delimiter=",") + odh.subjects = [] + odh.subjects = [np.zeros((len(d), 1)) for d in odh.data] + odh.extra_attributes.append('subjects') + data = odh + if split: + data = {'All': odh, 'Train': odh.isolate_data("sets", [0,1,2,3,4], fast=True), 'Test': odh.isolate_data("sets", [5,6], fast=True)} + + return data diff --git a/libemg/_datasets/radmand_lp.py b/libemg/_datasets/radmand_lp.py new file mode 100644 index 00000000..e521b195 --- /dev/null +++ b/libemg/_datasets/radmand_lp.py @@ -0,0 +1,40 @@ +from libemg._datasets.dataset import Dataset +from libemg.data_handler import OfflineDataHandler, RegexFilter + +class RadmandLP(Dataset): + def __init__(self, dataset_folder="LimbPosition/"): + Dataset.__init__(self, + 1000, + 6, + 'DelsysTrigno', + 10, + {'N/A': 'Uncertain'}, + '4 Reps (Train), 4 Reps x 15 Positions', + "A large limb position dataset (with 16 static limb positions).", + "https://pubmed.ncbi.nlm.nih.gov/25570046/") + self.url = "https://github.com/libemg/LimbPosition" + self.dataset_folder = dataset_folder + + def prepare_data(self, split = False): + subjects_values = [str(i) for i in range(1,11)] + position_values = ["P1", "P2", "P3", "P4", "P5", "P6", "P7", "P8", "P9", "P10", "P11", "P12", "P13", "P14", "P15", "P16"] + classes_values = [str(i) for i in range(1,9)] + reps_values = ["1","2","3","4"] + + print('\nPlease cite: ' + self.citation+'\n') + if (not self.check_exists(self.dataset_folder)): + self.download(self.url, self.dataset_folder) + + regex_filters = [ + RegexFilter(left_bound="/S", right_bound="/",values=subjects_values, description='subjects'), + RegexFilter(left_bound = "_", right_bound="_R", values = position_values, description='positions'), + RegexFilter(left_bound = "_C", right_bound="_P", values = classes_values, description='classes'), + RegexFilter(left_bound = "_R", right_bound=".csv", values = reps_values, description='reps'), + ] + odh = OfflineDataHandler() + odh.get_data(folder_location=self.dataset_folder + 'RadmandLimbPosition/', regex_filters=regex_filters, delimiter=",") + data = odh + if split: + data = {'All': odh, 'Train': odh.isolate_data("positions", [0], fast=True), 'Test': odh.isolate_data("positions", list(range(1, len(position_values))), fast=True)} + + return data \ No newline at end of file diff --git a/libemg/_datasets/tmr_shirleyryanabilitylab.py b/libemg/_datasets/tmr_shirleyryanabilitylab.py new file mode 100644 index 00000000..8b3eabc9 --- /dev/null +++ b/libemg/_datasets/tmr_shirleyryanabilitylab.py @@ -0,0 +1,63 @@ +from libemg._datasets.dataset import Dataset +from libemg.data_handler import OfflineDataHandler, RegexFilter + +class TMRShirleyRyanAbilityLab(Dataset): + def __init__(self, dataset_folder="TMR/"): + Dataset.__init__(self, + 1000, + 32, + 'Ag/AgCl', + 6, + {0:"HandOpen", + 1:"KeyGrip", + 2:"PowerGrip", + 3:"FinePinchOpened", + 4:"FinePinchClosed", + 5:"TripodOpened", + 6:"TripodClosed", + 7:"Tool", + 8:"Hook", + 9:"IndexPoint", + 10:"ThumbFlexion", + 11:"ThumbExtension", + 12:"ThumbAbduction", + 13:"ThumbAdduction", + 14:"IndexFlexion", + 15:"RingFlexion", + 16:"PinkyFlexion", + 17:"WristSupination", + 18:"WristPronation", + 19:"WristFlexion", + 20:"WristExtension", + 21:"RadialDeviation", + 22:"UlnarDeviation", + 23:"NoMotion"}, + 8, + '6 subjects, 8 reps, 24 motions, pre/post intervention', + "https://pmc.ncbi.nlm.nih.gov/articles/PMC9879512/") + self.url = "https://github.com/LibEMG/TMR_ShirleyRyanAbilityLab" + self.dataset_folder = dataset_folder + + def prepare_data(self, split = False): + subjects_values = ["1","2","3","4","7","10"] + reps_values = [str(i) for i in range(8)] + classes_values = [str(i) for i in range(24)] + intervention_values = ["preTMR","postTMR"] + + print('\nPlease cite: ' + self.citation+'\n') + if (not self.check_exists(self.dataset_folder)): + self.download(self.url, self.dataset_folder) + + regex_filters = [ + RegexFilter(left_bound="/S", right_bound="/",values=subjects_values, description='subjects'), + RegexFilter(left_bound = "_R", right_bound=".txt", values = reps_values, description='reps'), + RegexFilter(left_bound = "/C", right_bound="_R", values = classes_values, description='classes'), + RegexFilter(left_bound = "/", right_bound="/C", values = intervention_values, description='intervention') + ] + odh = OfflineDataHandler() + odh.get_data(folder_location=self.dataset_folder, regex_filters=regex_filters, delimiter=",") + data = odh + if split: + data = {'All': odh, 'Train': odh.isolate_data("reps", list(range(4)), fast=True), 'Test': odh.isolate_data("reps", list(range(4,8)), fast=True)} + + return data \ No newline at end of file diff --git a/libemg/_datasets/user_compliance.py b/libemg/_datasets/user_compliance.py new file mode 100644 index 00000000..2cfae0ca --- /dev/null +++ b/libemg/_datasets/user_compliance.py @@ -0,0 +1,52 @@ +from pathlib import Path + +from libemg._datasets.dataset import Dataset +from libemg.data_handler import OfflineDataHandler, RegexFilter, FilePackager + + +class UserComplianceDataset(Dataset): + def __init__(self, dataset_folder = 'UserComplianceDataset/', analysis = 'baseline'): + super().__init__( + sampling=1010, + num_channels=64, + recording_device='EMaGer', + num_subjects=6, + gestures={0: 'Hand Close (-) / Hand Open (+)', 1: 'Pronation (-) / Supination (+)'}, + num_reps=5, + description='Regression dataset used for investigation into user compliance during mimic training.', + citation='https://conferences.lib.unb.ca/index.php/mec/article/view/2507' + ) + self.url = 'https://github.com/LibEMG/UserComplianceDataset' + self.dataset_folder = dataset_folder + self.analysis = analysis + + def prepare_data(self, split = False): + if (not self.check_exists(self.dataset_folder)): + self.download(self.url, self.dataset_folder) + + regex_filters = [ + RegexFilter(left_bound='/', right_bound='/', values=['open-close', 'pro-sup'], description='movements'), + RegexFilter(left_bound='_R_', right_bound='.csv', values=[str(idx) for idx in range(self.num_reps)], description='reps'), + RegexFilter(left_bound='/', right_bound='/', values=['anticipation', 'all-or-nothing', 'baseline'], description='behaviours'), + RegexFilter(left_bound='/', right_bound='/', values=[f"subject-{str(idx).zfill(3)}" for idx in range(1, 7)], description='subjects') + ] + package_function = lambda x, y: Path(x).parent.absolute() == Path(y).parent.absolute() + metadata_fetchers = [FilePackager(RegexFilter(left_bound='/', right_bound='.txt', values=['labels'], description='labels'), package_function)] + odh = OfflineDataHandler() + odh.get_data(folder_location=self.dataset_folder, regex_filters=regex_filters, metadata_fetchers=metadata_fetchers) + data = odh + if split: + if self.analysis == 'baseline': + data = { + 'All': odh, + 'Train': odh.isolate_data('behaviours', [2], fast=True).isolate_data('reps', [0, 1, 2, 3], fast=True), + 'Test': odh.isolate_data('behaviours', [2], fast=True).isolate_data('reps', [4], fast=True) + } + elif self.analysis == 'all-or-nothing': + data = {'All': odh, 'Train': odh.isolate_data('behaviours', [1], fast=True), 'Test': odh.isolate_data('behaviours', [2], fast=True)} + elif self.analysis == 'anticipation': + data = {'All': odh, 'Train': odh.isolate_data('behaviours', [0], fast=True), 'Test': odh.isolate_data('behaviours', [2], fast=True)} + else: + raise ValueError(f"Unexpected value for analysis. Got: {self.analysis}.") + + return data diff --git a/libemg/_streamers/_oymotion_streamer.py b/libemg/_streamers/_oymotion_streamer.py index 8bdabada..3bd6dedf 100644 --- a/libemg/_streamers/_oymotion_streamer.py +++ b/libemg/_streamers/_oymotion_streamer.py @@ -68,8 +68,11 @@ def start_stream(self): ## BEGIN HARDWARE SPECIFIC CONFIG import platform if platform.system() == 'Linux': - from bluepy import btle - from bluepy.btle import DefaultDelegate, Scanner, Peripheral + try: + from bluepy import btle + from bluepy.btle import DefaultDelegate, Scanner, Peripheral + except: + pass from datetime import datetime, timedelta import struct from enum import Enum @@ -288,31 +291,34 @@ def __init__(self, _cmd, _timeoutTime, _cb): self._cb = _cb if platform.system() == 'Linux': - class MyDelegate(btle.DefaultDelegate): - def __init__(self, gforce): - super().__init__() - self.gforce = gforce - self.bluepy_thread = threading.Thread(target=self.bluepy_handler) - self.bluepy_thread.setDaemon(True) - self.bluepy_thread.start() - - def bluepy_handler(self): - while True: - if not self.gforce.send_queue.empty(): - cmd = self.gforce.send_queue.get_nowait() - self.gforce.cmdCharacteristic.write(cmd) - self.gforce.device.waitForNotifications(1) - - def handleNotification(self, cHandle, data): - # check cHandle - # self.gforce.lock.acquire() - if cHandle == self.gforce.cmdCharacteristic.getHandle(): - self.gforce._onResponse(data) - - # check cHandle - if cHandle == self.gforce.notifyCharacteristic.getHandle(): - self.gforce.handleDataNotification(data, self.gforce.onData) - # self.gforce.lock.release() + try: + class MyDelegate(btle.DefaultDelegate): + def __init__(self, gforce): + super().__init__() + self.gforce = gforce + self.bluepy_thread = threading.Thread(target=self.bluepy_handler) + self.bluepy_thread.setDaemon(True) + self.bluepy_thread.start() + + def bluepy_handler(self): + while True: + if not self.gforce.send_queue.empty(): + cmd = self.gforce.send_queue.get_nowait() + self.gforce.cmdCharacteristic.write(cmd) + self.gforce.device.waitForNotifications(1) + + def handleNotification(self, cHandle, data): + # check cHandle + # self.gforce.lock.acquire() + if cHandle == self.gforce.cmdCharacteristic.getHandle(): + self.gforce._onResponse(data) + + # check cHandle + if cHandle == self.gforce.notifyCharacteristic.getHandle(): + self.gforce.handleDataNotification(data, self.gforce.onData) + # self.gforce.lock.release() + except: + print('Bluepy not installed...') class GForceProfile(): diff --git a/libemg/_streamers/_sifi_bridge_streamer.py b/libemg/_streamers/_sifi_bridge_streamer.py index eea7735b..4d9e39b6 100644 --- a/libemg/_streamers/_sifi_bridge_streamer.py +++ b/libemg/_streamers/_sifi_bridge_streamer.py @@ -22,8 +22,8 @@ class SiFiBridgeStreamer(Process): Parameters ---------- - version : str - The version of the devie ('1_1 for bioarmband, 1_2 or 1_3 for biopoint). + name : str + The name of the device. shared_memory_items : list Shared memory configuration parameters for the streamer in format: ["tag", (size), datatype, Lock()]. @@ -62,7 +62,7 @@ class SiFiBridgeStreamer(Process): """ def __init__(self, - version: str = '1_2', + name: str | None = None, shared_memory_items: list = [], ecg: bool = False, emg: bool = True, @@ -96,7 +96,7 @@ def __init__(self, self.prepare_config_message(ecg, emg, eda, imu, ppg, notch_on, notch_freq, emgfir_on, emg_fir, eda_cfg, fc_lp, fc_hp, freq, streaming) - self.prepare_connect_message(version, mac) + self.prepare_connect_message(name, mac) self.prepare_executable(bridge_version) @@ -139,12 +139,12 @@ def prepare_config_message(self, self.config_message = bytes(self.config_message,"UTF-8") def prepare_connect_message(self, - version: str, + name: str, mac : str): if mac is not None: self.connect_message = '-c ' + str(mac) + '\n' else: - self.connect_message = '-c BioPoint_v' + str(version) + '\n' + self.connect_message = '-c ' + str(name) + '\n' self.connect_message = bytes(self.connect_message,"UTF-8") def prepare_executable(self, diff --git a/libemg/data_handler.py b/libemg/data_handler.py index ec057c98..6b8bf3b2 100644 --- a/libemg/data_handler.py +++ b/libemg/data_handler.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Callable, Sequence +from typing import Callable, Sequence, Any import numpy as np import numpy.typing as npt import pandas as pd @@ -29,7 +29,7 @@ from libemg.utils import get_windows, _get_fn_windows, _get_mode_windows, make_regex class RegexFilter: - def __init__(self, left_bound: str, right_bound: str, values: Sequence, description: str): + def __init__(self, left_bound: str, right_bound: str, values: Sequence[str] | None = None, description: str | None = None, return_value = False): """Filters files based on filenames that match the associated regex pattern and grabs metadata based on the regex pattern. Parameters @@ -41,11 +41,12 @@ def __init__(self, left_bound: str, right_bound: str, values: Sequence, descript values: list The values between the two regexes. description: str - Description of filter - used to name the metadata field. + Description of filter - used to name the metadata field. Pass in an empty string to filter files without storing the values as metadata. """ self.pattern = make_regex(left_bound, right_bound, values) self.values = values self.description = description + self.return_value = return_value def get_matching_files(self, files: Sequence[str]): """Filter out files that don't match the regex pattern and return the matching files. @@ -78,8 +79,20 @@ def get_metadata(self, filename: str): """ # this is how it should work to be the same as the ODH, but we can maybe discuss redoing this so it saves the actual value instead of the indices. might be confusing to pass values to get data but indices to isolate it. also not sure if it needs to be arrays val = re.findall(self.pattern, filename)[0] - idx = self.values.index(val) - return idx + if (self.values is None) or self.return_value: + # We want to store as a number if at all possible to save on memory + try: + return int(val) + except ValueError: + ... + + try: + return float(val) + except ValueError: + # Can't cast to a number, so we return a string + return val + else: + return self.values.index(val) class MetadataFetcher(ABC): @@ -94,8 +107,9 @@ def __init__(self, description: str): self.description = description @abstractmethod - def __call__(self, filename: str, file_data: npt.NDArray, all_files: Sequence[str]): + def __call__(self, filename: str, file_data: npt.NDArray, all_files: Sequence[str]) -> Any: """Fetch metadata. Must return a (N x M) numpy.ndarray, where N is the number of samples in the EMG data and M is the number of columns in the metadata. + If a single value array is returned (0D or 1D), it will be cast to a N x 1 array where all values are the original value. Parameters ---------- @@ -111,36 +125,59 @@ def __call__(self, filename: str, file_data: npt.NDArray, all_files: Sequence[st metadata: np.ndarray Array containing the metadata corresponding to the provided file. """ - raise NotImplementedError("Must implement __call__ method.") + ... class FilePackager(MetadataFetcher): - def __init__(self, regex_filter: RegexFilter, package_function: Callable[[str, str], bool], align_method: str | Callable[[npt.NDArray, npt.NDArray], npt.NDArray] = 'zoom', load = None, column_mask = None): + def __init__(self, regex_filter: RegexFilter, package_function: Callable[[str, str], bool] | Sequence[RegexFilter], + align_method: str | Callable[[npt.NDArray, npt.NDArray], npt.NDArray] = 'zoom', load: Callable[[str], npt.NDArray] | str | None = None, column_mask: Sequence[int] | None = None): """Package data file with another file that contains relevant metadata (e.g., a labels file). Cycles through all files that match the RegexFilter and packages a data file with a metadata file based on a packaging function. Parameters ---------- regex_filter: RegexFilter - Used to find the type of metadata files. - package_function: callable + Used to find the type of metadata files. The description of this RegexFilter is used to assign the name of the field for this metadata in the OfflineDataHandler. + package_function: callable or Sequence[RegexFilter] Function handle used to determine if two files should be packaged together (i.e., found the metadata file that goes with the data file). Takes in the filename of a metadata file and the filename of the data file. Should return True if the files should be packaged together and False if not. + Alternatively, a list of RegexFilters can be passed in and a function will be created that packages files only if the regex metadata of the data filename + and metadata filename match. align_method: str or callable, default='zoom' Method for aligning the samples of the metadata file and data file. Pass in 'zoom' for the metadata file to be zoomed using spline interpolation to the size of the data file or pass in a callable that takes in the metadata and the EMG data and returns the aligned metadata. - load: callable or None, default=None - Custom loading function for metadata file. If None is passed, the metadata is loaded based on the file extension (only .csv and .txt are supported). + load: callable, str, or None, default=None + Determines how metadata file is loaded. If a custom loading function, should take in the filename and return an array. If a string, + it is assumed to be the MRDF key of a .hea file. If None is passed, the metadata is loaded based on the file extension (only .csv and .txt are supported). column_mask: list or None, default=None List of integers corresponding to the indices of the columns that should be extracted from the raw file data. If None is passed, all columns are extracted. """ + assert regex_filter.description is not None, 'RegexFilter must have a description, otherwise metadata will not be stored.' super().__init__(regex_filter.description) self.regex_filter = regex_filter + self.package_filters = None + + if isinstance(package_function, Sequence): + # Create function to ensure metadata matches + self.package_filters = copy.deepcopy(package_function) + package_function = self._match_regex_patterns self.package_function = package_function self.align_method = align_method self.load = load self.column_mask = column_mask + def _match_regex_patterns(self, metadata_file: str, data_file: str): + assert self.package_filters is not None, 'Attempting to match package filters, but None found.' + for filter in self.package_filters: + if len(filter.get_matching_files([metadata_file])) == 0: + # Doesn't match filters + return False + matching_metadata = filter.get_metadata(metadata_file) == filter.get_metadata(data_file) + if not matching_metadata: + return False + return True + + def __call__(self, filename: str, file_data: npt.NDArray, all_files: Sequence[str]): potential_files = self.regex_filter.get_matching_files(all_files) packaged_files = [Path(potential_file) for potential_file in potential_files if self.package_function(potential_file, filename)] @@ -148,13 +185,19 @@ def __call__(self, filename: str, file_data: npt.NDArray, all_files: Sequence[st # I think it's easier to enforce a single file per FilePackager, but we could build in functionality to allow multiple files then just vstack all the data if there's a use case for that. raise ValueError(f"Found {len(packaged_files)} files to be packaged with {filename} when trying to package {self.regex_filter.description} file (1 file should be found). Please check filter and package functions.") packaged_file = packaged_files[0] + suffix = packaged_file.suffix + packaged_file = packaged_file.as_posix() if callable(self.load): # Passed in a custom loading function packaged_file_data = self.load(packaged_file) - elif packaged_file.suffix == '.txt': + elif isinstance(self.load, str): + # Passed in a MRDF key + assert suffix == '.hea', f"Provided string for load parameter, but packaged file doesn't have extension .hea. Please pass in a custom load function and/or ensure the correct file is packaged." + packaged_file_data = (wfdb.rdrecord(packaged_file.replace('.hea', ''))).__getattribute__(self.load) + elif suffix == '.txt': packaged_file_data = np.loadtxt(packaged_file, delimiter=',') - elif packaged_file.suffix == '.csv': + elif suffix == '.csv': packaged_file_data = pd.read_csv(packaged_file) packaged_file_data = packaged_file_data.to_numpy() else: @@ -163,7 +206,7 @@ def __call__(self, filename: str, file_data: npt.NDArray, all_files: Sequence[st # Align with EMG data if self.align_method == 'zoom': zoom_rate = file_data.shape[0] / packaged_file_data.shape[0] - zoom_factor = [zoom_rate if idx == 0 else 1 for idx in range(packaged_file_data.shape[1])] # only align the 0th axis (samples) + zoom_factor = (zoom_rate, 1) # only align the 0th axis (samples) packaged_file_data = zoom(packaged_file_data, zoom=zoom_factor) elif callable(self.align_method): packaged_file_data = self.align_method(packaged_file_data, file_data) @@ -181,7 +224,7 @@ def __call__(self, filename: str, file_data: npt.NDArray, all_files: Sequence[st return packaged_file_data -class ColumnFetch(MetadataFetcher): +class ColumnFetcher(MetadataFetcher): def __init__(self, description: str, column_mask: Sequence[int] | int, values: Sequence | None = None): """Fetch metadata from columns within data file. @@ -260,6 +303,17 @@ def __add__(self, other): setattr(new_odh, self_attribute, new_value) return new_odh + def _append_to_attribute(self, name, value): + if name is None: + # Don't want this data saved to data handler, so skip it + return + if not hasattr(self, name): + setattr(self, name, []) + self.extra_attributes.append(name) + current_value = getattr(self, name) + setattr(self, name, current_value + [value]) + + def get_data(self, folder_location: str, regex_filters: Sequence[RegexFilter], metadata_fetchers: Sequence[MetadataFetcher] | None = None, delimiter: str = ',', mrdf_key: str = 'p_signal', skiprows: int = 0, data_column: Sequence[int] | None = None, downsampling_factor: int | None = None): """Method to collect data from a folder into the OfflineDataHandler object. The relevant data files can be selected based on passing in @@ -293,12 +347,6 @@ def get_data(self, folder_location: str, regex_filters: Sequence[RegexFilter], m ValueError: Raises ValueError if folder_location is not a valid directory. """ - def append_to_attribute(name, value): - if not hasattr(self, name): - setattr(self, name, []) - self.extra_attributes.append(name) - current_value = getattr(self, name) - setattr(self, name, current_value + [value]) if not os.path.isdir(folder_location): raise ValueError(f"Folder location {folder_location} is not a directory.") @@ -338,17 +386,19 @@ def append_to_attribute(name, value): # Fetch metadata from filename for regex_filter in regex_filters: - metadata_idx = regex_filter.get_metadata(file) - metadata = metadata_idx * np.ones((file_data.shape[0], 1), dtype=int) - append_to_attribute(regex_filter.description, metadata) + metadata = regex_filter.get_metadata(file) + self._append_to_attribute(regex_filter.description, metadata) # Fetch remaining metadata for metadata_fetcher in metadata_fetchers: metadata = metadata_fetcher(file, file_data, all_files) - if metadata.ndim == 1: - # Ensure that output is always 2D array - metadata = np.expand_dims(metadata, axis=1) - append_to_attribute(metadata_fetcher.description, metadata) + if isinstance(metadata, np.ndarray): + if metadata.ndim == 0 or metadata.shape[0] == 1: + metadata = metadata.item() + elif metadata.ndim == 1: + # Ensure that output is always 2D array + metadata = np.expand_dims(metadata, axis=1) + self._append_to_attribute(metadata_fetcher.description, metadata) def active_threshold(self, nm_windows, active_windows, active_labels, num_std=3, nm_label=0, silent=True): """Returns an update label list of the active labels for a ramp contraction. @@ -395,6 +445,13 @@ def parse_windows(self, window_size, window_increment, metadata_operations=None) The number of samples in a window. window_increment: int The number of samples that advances before next window. + metadata_operations: dict or None (optional),default=None + Specifies which operations should be performed on metadata attributes when performing windowing. By default, + all metadata is stored as its mode in a window. To change this behaviour, specify the metadata attribute as the key and + the operation as the value in the dictionary. The operation (value) should either be an accepted string (mean, median, last_sample) or + a function handle that takes in an ndarray of size (window_size, ) and returns a single value to represent the metadata for that window. Passing in a string + will map from that string to the specified operation. The windowing of only the attributes specified in this dictionary will be modified - all other + attributes will default to the mode. If None, all attributes default to the mode. Defaults to None. Returns ---------- @@ -407,34 +464,41 @@ def parse_windows(self, window_size, window_increment, metadata_operations=None) return self._parse_windows_helper(window_size, window_increment, metadata_operations) def _parse_windows_helper(self, window_size, window_increment, metadata_operations): - metadata_ = {} + common_metadata_operations = { + 'mean': np.mean, + 'median': np.median, + 'last_sample': lambda x: x[-1] + } + window_data = [] + metadata = {k: [] for k in self.extra_attributes} for i, file in enumerate(self.data): # emg data windowing - windows = get_windows(file,window_size,window_increment) - if "windows_" in locals(): - windows_ = np.concatenate((windows_, windows)) - else: - windows_ = windows - # metadata windowing + window_data.append(get_windows(file,window_size,window_increment)) + for k in self.extra_attributes: - if type(getattr(self,k)[i]) != np.ndarray: - file_metadata = np.ones((windows.shape[0])) * getattr(self, k)[i] + file_attribute = getattr(self, k)[i] + if not isinstance(file_attribute, np.ndarray): + file_metadata = np.full(window_data[-1].shape[0], fill_value=file_attribute) else: if metadata_operations is not None: if k in metadata_operations.keys(): # do the specified operation - file_metadata = _get_fn_windows(getattr(self,k)[i], window_size, window_increment, metadata_operations[k]) + operation = metadata_operations[k] + + if isinstance(operation, str): + try: + operation = common_metadata_operations[operation] + except KeyError as e: + raise KeyError(f"Unexpected metadata operation string. Please pass in a function or an accepted string {tuple(common_metadata_operations.keys())}. Got: {operation}.") + file_metadata = _get_fn_windows(file_attribute, window_size, window_increment, operation) else: - file_metadata = _get_mode_windows(getattr(self,k)[i], window_size, window_increment) + file_metadata = _get_mode_windows(file_attribute, window_size, window_increment) else: - file_metadata = _get_mode_windows(getattr(self,k)[i], window_size, window_increment) - if k not in metadata_.keys(): - metadata_[k] = file_metadata - else: - metadata_[k] = np.concatenate((metadata_[k], file_metadata)) - + file_metadata = _get_mode_windows(file_attribute, window_size, window_increment) + + metadata[k].append(file_metadata) - return windows_, metadata_ + return np.vstack(window_data), {k: np.concatenate(metadata[k], axis=0) for k in metadata.keys()} def isolate_channels(self, channels): @@ -461,7 +525,7 @@ def isolate_channels(self, channels): new_odh.data[i] = new_odh.data[i][:,channels] return new_odh - def isolate_data(self, key, values): + def isolate_data(self, key, values, fast=False): """Entry point for isolating a single key of data within the offline data handler. First, error checking is performed within this method, then if it passes, the isolate_data_helper is called to make a new OfflineDataHandler that contains only that data. @@ -471,6 +535,8 @@ def isolate_data(self, key, values): The metadata key that will be used to filter (e.g., "subject", "rep", "class", "set", whatever you'd like). values: list A list of values that you want to isolate. (e.g. [0,1,2,3]). Indexing starts at 0. + fast: Boolean (default=False) + If true, it iterates over the median value for each EMG element. This should be used when parsing on things like reps, subjects, classes, etc. Returns ---------- @@ -479,47 +545,35 @@ def isolate_data(self, key, values): """ assert key in self.extra_attributes assert type(values) == list - return self._isolate_data_helper(key,values) + return self._isolate_data_helper(key,values,fast) - def _isolate_data_helper(self, key, values): + def _isolate_data_helper(self, key, values,fast): new_odh = OfflineDataHandler() setattr(new_odh, "extra_attributes", self.extra_attributes) key_attr = getattr(self, key) - - # if these end up being ndarrays, it means that the metadata was IN the csv file. - - if type(key_attr[0]) == np.ndarray: - # for every file (list element) - data = [] - for f in range(len(key_attr)): - # get the keep_mask - keep_mask = list([i in values for i in key_attr[f]]) - # append the valid data - if self.data[f][keep_mask,:].shape[0]> 0: - data.append(self.data[f][keep_mask,:]) - setattr(new_odh, "data", data) + for e in self.extra_attributes: + setattr(new_odh, e, []) + + for file_idx in range(len(key_attr)): + file_data = self.data[file_idx] + file_metadata = key_attr[file_idx] + if isinstance(file_metadata, np.ndarray): + keep_mask = np.full(file_metadata.shape[0], fill_value=False) + for value in values: + keep_mask = keep_mask | (file_metadata == value) + else: + keep = file_metadata in values + keep_mask = np.full(file_data.shape[0], fill_value=keep) + + if file_data[keep_mask].shape[0] > 0: + new_odh.data.append(file_data[keep_mask]) + for e in self.extra_attributes: + new_metadata = getattr(self, e)[file_idx] + if isinstance(new_metadata, np.ndarray): + new_metadata = new_metadata[keep_mask] + new_odh._append_to_attribute(e, new_metadata) + - for k in self.extra_attributes: - key_value = getattr(self, k) - if type(key_value[0]) == np.ndarray: - # the other metadata that is in the csv file should be sliced the same way as the ndarray - key = [] - for f in range(len(key_attr)): - keep_mask = list([i in values for i in key_attr[f]]) - if key_value[f][keep_mask,:].shape[0]>0: - key.append(key_value[f][keep_mask,:]) - setattr(new_odh, k, key) - - else: - assert False # we should never get here - # # if the other metadata was not in the csv file (i.e. subject label in filename but classes in csv), then just keep it - # setattr(new_odh, k, key_value) - else: - assert False # we should never get here - # keep_mask = list([i in values for i in key_attr]) - # setattr(new_odh, "data", list(compress(self.data, keep_mask))) - # for k in self.extra_attributes: - # setattr(new_odh, k,list(compress(getattr(self, k), keep_mask))) return new_odh def visualize(): @@ -536,13 +590,17 @@ class OnlineDataHandler(DataHandler): ---------- shared_memory_items: Object The shared memory object returned from the streamer. + channel_mask: list or None (optional), default=None + Mask of active channels to use online. Allows certain channels to be ignored when streaming in real-time. If None, all channels are used. + Defaults to None. """ - def __init__(self, shared_memory_items): + def __init__(self, shared_memory_items, channel_mask = None): self.shared_memory_items = shared_memory_items self.prepare_smm() self.log_signal = Event() self.visualize_signal = Event() self.fi = None + self.channel_mask = channel_mask def prepare_smm(self): self.modalities = [] @@ -584,6 +642,17 @@ def install_filter(self, fi): """ self.fi = fi + def install_channel_mask(self, mask): + """Install a channel mask to isolate certain channels for online streaming. + + Parameters + ---------- + mask: list or None (optional), default=None + Mask of active channels to use online. Allows certain channels to be ignored when streaming in real-time. If None, all channels are used. + Defaults to None. + """ + self.channel_mask = mask + def analyze_hardware(self, analyze_time=10): """Analyzes several metrics from the hardware: @@ -771,7 +840,8 @@ def extract_data(): # Extract features along each channel windows = data[np.newaxis].transpose(0, 2, 1) # add axis and tranpose to convert to (windows x channels x samples) fe = FeatureExtractor() - feature_set_dict = fe.extract_features(feature_list, windows) + feature_set_dict = fe.extract_features(feature_list, windows, array=False) + assert isinstance(feature_set_dict, dict), f"Expected dictionary of features. Got: {type(feature_set_dict)}." if remap_function is not None: # Remap raw data to image format for key in feature_set_dict: @@ -949,6 +1019,8 @@ def get_data(self, N=0, filter=True): val[mod] = data[:N,:] else: val[mod] = data[:,:] + if self.channel_mask is not None: + val[mod] = val[mod][:, self.channel_mask] count[mod] = self.smm.get_variable(mod+"_count") return val,count @@ -1034,4 +1106,4 @@ def _check_streaming(self, timeout=15): def start_listening(self): print("LibEMG>v1.0 no longer requires online_data_handler.start_listening().\nThis is deprecated.") - pass + pass \ No newline at end of file diff --git a/libemg/datasets.py b/libemg/datasets.py index 04f67e8c..72161a04 100644 --- a/libemg/datasets.py +++ b/libemg/datasets.py @@ -1,490 +1,196 @@ -import os -import numpy as np -import zipfile -import scipy.io as sio -from libemg.data_handler import ColumnFetch, MetadataFetcher, OfflineDataHandler, RegexFilter, FilePackager -from libemg.utils import make_regex -from glob import glob -from os import walk -from pathlib import Path -from datetime import datetime -# this assumes you have git downloaded (not pygit, but the command line program git) - -class Dataset: - def __init__(self, save_dir='.', redownload=False): - self.save_dir = save_dir - self.redownload=redownload - - def download(self, url, dataset_name): - clone_command = "git clone " + url + " " + dataset_name - os.system(clone_command) +from libemg._datasets._3DC import _3DCDataset +from libemg._datasets.one_subject_myo import OneSubjectMyoDataset +from libemg._datasets.one_subject_emager import OneSubjectEMaGerDataset +from libemg._datasets.emg_epn612 import EMGEPN612 +from libemg._datasets.ciil import CIIL_MinimalData, CIIL_ElectrodeShift, CIIL_WeaklySupervised +from libemg._datasets.grab_myo import GRABMyoBaseline, GRABMyoCrossDay +from libemg._datasets.continous_transitions import ContinuousTransitions +from libemg._datasets.nina_pro import NinaproDB2, NinaproDB8 +from libemg._datasets.myodisco import MyoDisCo +from libemg._datasets.user_compliance import UserComplianceDataset +from libemg._datasets.fors_emg import FORSEMG +from libemg._datasets.radmand_lp import RadmandLP +from libemg._datasets.fougner_lp import FougnerLP +from libemg._datasets.intensity import ContractionIntensity +from libemg._datasets.hyser import Hyser1DOF, HyserNDOF, HyserRandom, HyserPR +from libemg._datasets.kaufmann_md import KaufmannMD +from libemg._datasets.tmr_shirleyryanabilitylab import TMRShirleyRyanAbilityLab +from libemg._datasets.one_site_biopoint import OneSiteBiopoint +from libemg.feature_extractor import FeatureExtractor +from libemg.emg_predictor import EMGClassifier, EMGRegressor +from libemg.offline_metrics import OfflineMetrics +import pickle +import time + +def get_dataset_list(type='CLASSIFICATION'): + """Gets a list of all available datasets. + + Parameters + ---------- + type: str (default='CLASSIFICATION') + The type of datasets to return. Valid Options: 'CLASSIFICATION', 'REGRESSION', and 'ALL'. - def remove_dataset(self, dataset_folder): - remove_command = "rm -rf " + dataset_folder - os.system(remove_command) - - def check_exists(self, dataset_folder): - return os.path.exists(dataset_folder) - - def prepare_data(self, format=OfflineDataHandler): - pass - - -class _3DCDataset(Dataset): - def __init__(self, save_dir='.', redownload=False, dataset_name="_3DCDataset"): - Dataset.__init__(self, save_dir, redownload) - self.url = "https://github.com/libemg/3DCDataset" - self.dataset_name = dataset_name - self.dataset_folder = os.path.join(self.save_dir , self.dataset_name) - self.class_list = ["Neutral", "Radial Deviation", "Wrist Flexion", "Ulnar Deviation", "Wrist Extension", "Supination", - "Pronation", "Power Grip", "Open Hand", "Chuck Grip", "Pinch Grip"] - - if (not self.check_exists(self.dataset_folder)): - self.download(self.url, self.dataset_folder) - elif (self.redownload): - self.remove_dataset(self.dataset_folder) - self.download(self.url, self.dataset_folder) - - - - def prepare_data(self, format=OfflineDataHandler, subjects_values = [str(i) for i in range(1,23)], - sets_values = ["train", "test"], - reps_values = ["0","1","2","3"], - classes_values = [str(i) for i in range(11)]): - if format == OfflineDataHandler: - regex_filters = [ - RegexFilter(left_bound = "/", right_bound="/EMG", values = sets_values, description='sets'), - RegexFilter(left_bound = "_", right_bound=".txt", values = classes_values, description='classes'), - RegexFilter(left_bound = "EMG_gesture_", right_bound="_", values = reps_values, description='reps'), - RegexFilter(left_bound="Participant", right_bound="/",values=subjects_values, description='subjects') - ] - odh = OfflineDataHandler() - odh.get_data(folder_location=self.dataset_folder, regex_filters=regex_filters, delimiter=",") - return odh - -class Ninapro(Dataset): - def __init__(self, save_dir='.', dataset_name="Ninapro"): - # downloading the Ninapro dataset is not supported (no permission given from the authors)' - # however, you can download it from http://ninapro.hevs.ch/DB8 - # the subject zip files should be placed at: /NinaproDB8/DB8_s#.zip - Dataset.__init__(self, save_dir) - self.dataset_name = dataset_name - self.dataset_folder = os.path.join(self.save_dir , self.dataset_name, "") - self.exercise_step = [] + Returns + ---------- + dictionary + A dictionary with the all available datasets and their respective classes. + """ + type = type.upper() + if type not in ['CLASSIFICATION', 'REGRESSION', 'WEAKLYSUPERVISED', 'ALL']: + print('Valid Options for type parameter: \'CLASSIFICATION\', \'REGRESSION\', or \'ALL\'.') + return {} - def convert_to_compatible(self): - # get the zip files (original format they're downloaded in) - zip_files = find_all_files_of_type_recursively(self.dataset_folder,".zip") - # unzip the files -- if any are there (successive runs skip this) - for zip_file in zip_files: - with zipfile.ZipFile(zip_file, 'r') as zip_ref: - zip_ref.extractall(zip_file[:-4]+'/') - os.remove(zip_file) - # get the mat files (the files we want to convert to csv) - mat_files = find_all_files_of_type_recursively(self.dataset_folder,".mat") - for mat_file in mat_files: - self.convert_to_csv(mat_file) + classification = { + 'OneSubjectMyo': OneSubjectMyoDataset, + '3DC': _3DCDataset, + 'CIIL_MinimalData': CIIL_MinimalData, + 'CIIL_ElectrodeShift': CIIL_ElectrodeShift, + 'GRABMyoBaseline': GRABMyoBaseline, + 'GRABMyoCrossDay': GRABMyoCrossDay, + 'ContinuousTransitions': ContinuousTransitions, + 'NinaProDB2': NinaproDB2, + 'FORS-EMG': FORSEMG, + 'EMGEPN612': EMGEPN612, + 'ContractionIntensity': ContractionIntensity, + 'RadmandLP': RadmandLP, + 'FougnerLP': FougnerLP, + 'KaufmannMD': KaufmannMD, + 'TMRShirleyRyanAbilityLab' : TMRShirleyRyanAbilityLab, + 'HyserPR': HyserPR, + 'OneSiteBioPoint': OneSiteBiopoint + } + + regression = { + 'OneSubjectEMaGer': OneSubjectEMaGerDataset, + 'NinaProDB8': NinaproDB8, + 'Hyser1DOF': Hyser1DOF, + 'HyserNDOF': HyserNDOF, + 'HyserRandom': HyserRandom, + 'UserCompliance': UserComplianceDataset + } + + weaklysupervised = { + 'CIILWeaklySupervised': CIIL_WeaklySupervised + } - def convert_to_csv(self, mat_file): - # read the mat file - mat_file = mat_file.replace("\\", "/") - mat_dir = mat_file.split('/') - mat_dir = os.path.join(*mat_dir[:-1],"") - mat = sio.loadmat(mat_file) - # get the data - exercise = int(mat_file.split('_')[3][1]) - exercise_offset = self.exercise_step[exercise-1] # 0 reps already included - data = mat['emg'] - restimulus = mat['restimulus'] - rerepetition = mat['rerepetition'] - if data.shape[0] != restimulus.shape[0]: # this happens in some cases - min_shape = min([data.shape[0], restimulus.shape[0]]) - data = data[:min_shape,:] - restimulus = restimulus[:min_shape,] - rerepetition = rerepetition[:min_shape,] - # remove 0 repetition - collection buffer - remove_mask = (rerepetition != 0).squeeze() - data = data[remove_mask,:] - restimulus = restimulus[remove_mask] - rerepetition = rerepetition[remove_mask] - # important little not here: - # the "rest" really is only the rest between motions, not a dedicated rest class. - # there will be many more rest repetitions (as it is between every class) - # so usually we really care about classifying rest as its important (most of the time we do nothing) - # but for this dataset it doesn't make sense to include (and not its just an offline showcase of the library) - # I encourage you to plot the restimulus to see what I mean. -> plt.plot(restimulus) - # so we remove the rest class too - remove_mask = (restimulus != 0).squeeze() - data = data[remove_mask,:] - restimulus = restimulus[remove_mask] - rerepetition = rerepetition[remove_mask] - tail = 0 - while tail < data.shape[0]-1: - rep = rerepetition[tail][0] # remove the 1 offset (0 was the collection buffer) - motion = restimulus[tail][0] # remove the 1 offset (0 was between motions "rest") - # find head - head = np.where(rerepetition[tail:] != rep)[0] - if head.shape == (0,): # last segment of data - head = data.shape[0] -1 - else: - head = head[0] + tail - # downsample to 1kHz from 2kHz using decimation - data_for_file = data[tail:head,:] - data_for_file = data_for_file[::2, :] - # write to csv - csv_file = mat_dir + 'C' + str(motion-1) + 'R' + str(rep-1 + exercise_offset) + '.csv' - np.savetxt(csv_file, data_for_file, delimiter=',') - tail = head - os.remove(mat_file) - -class NinaproDB8(Ninapro): - def __init__(self, save_dir='.', dataset_name="NinaproDB8"): - Ninapro.__init__(self, save_dir, dataset_name) - self.class_list = ["Thumb Flexion/Extension", "Thumb Abduction/Adduction", "Index Finger Flexion/Extension", "Middle Finger Flexion/Extension", "Combined Ring and Little Fingers Flexion/Extension", - "Index Pointer", "Cylindrical Grip", "Lateral Grip", "Tripod Grip"] - self.exercise_step = [0,10,20] + if type == 'CLASSIFICATION': + return classification + elif type == 'REGRESSION': + return regression + elif type == "WEAKLYSUPERVISED": + return weaklysupervised + else: + # Concatenate all datasets + classification.update(regression) + classification.update(weaklysupervised) + return classification + +def get_dataset_info(dataset): + """Prints out the information about a certain dataset. + + Parameters + ---------- + dataset: string + The name of the dataset you want the information of. + """ + if dataset in get_dataset_list(): + get_dataset_list()[dataset]().get_info() + else: + print("ERROR: Invalid dataset name") + +def evaluate(model, window_size, window_inc, feature_list=['MAV'], feature_dic={}, included_datasets=['OneSubjectMyo', '3DC'], output_file='out.pkl', regression=False, metrics=['CA']): + """Evaluates an algorithm against all included datasets. + + Parameters + ---------- + window_size: int + The window size (**in ms**). + window_inc: int + The window increment (**in ms**). + feature_list: list (default=['MAV']) + A list of features. + feature_dic: dic (default={}) + A dictionary of parameters for the passed in features. + included_dataasets: list (str) or list (DataSets) + The name of the datasets you want to evaluate your model on. Either pass in strings (e.g., '3DC') for names or the dataset objects (e.g., _3DCDataset()). + output_file: string (default='out.pkl') + The name of the directory you want to incrementally save the results to (it will be a pickle file). + regression: boolean (default=False) + If True, will create an EMGRegressor object. Otherwise creates an EMGClassifier object. + metrics: list (default=['CA']/['MSE']) + The metrics to extract from each dataset. + Returns + ---------- + dictionary + A dictionary with a set of accuracies for different datasets + """ + + # -------------- Setup ------------------- + if metrics == ['CA'] and regression: + metrics = ['MSE'] + + metadata_operations = None + label_val = 'classes' + if regression: + metadata_operations = {'labels': 'last_sample'} + label_val = 'labels' + + om = OfflineMetrics() + + # --------------- Run ----------------- + accuracies = {} + for d in included_datasets: + print('Evaluating ' + d + ' dataset...') + if isinstance(d, str): + dataset = get_dataset_list('ALL')[d]() + else: + dataset = d - def prepare_data(self, format=OfflineDataHandler, subjects_values = [str(i) for i in range(1,13)], - reps_values = [str(i) for i in range(22)], - classes_values = [str(i) for i in range(9)]): + if isinstance(dataset, EMGEPN612): + print('EMGEPN612 Dataset is meant for cross user modelling... Skipping.') + continue - if format == OfflineDataHandler: - regex_filters = [ - RegexFilter(left_bound = "/C", right_bound="R", values = classes_values, description='classes'), - RegexFilter(left_bound = "R", right_bound=".csv", values = reps_values, description='reps'), - RegexFilter(left_bound="DB8_s", right_bound="/",values=subjects_values, description='subjects') - ] - odh = OfflineDataHandler() - odh.get_data(folder_location=self.dataset_folder, regex_filters=regex_filters, delimiter=",") - return odh - -class NinaproDB2(Ninapro): - def __init__(self, save_dir='.', dataset_name="NinaproDB2"): - Ninapro.__init__(self, save_dir, dataset_name) - self.class_list = ["TODO"] - self.exercise_step = [0,0,0] - - def prepare_data(self, format=OfflineDataHandler, subjects_values = [str(i) for i in range(1,41)], - reps_values = [str(i) for i in range(6)], - classes_values = [str(i) for i in range(50)]): + data = dataset.prepare_data(split=True) - if format == OfflineDataHandler: - regex_filters = [ - RegexFilter(left_bound = "/C", right_bound="R", values = classes_values, description='classes'), - RegexFilter(left_bound = "R", right_bound=".csv", values = reps_values, description='reps'), - RegexFilter(left_bound="DB2_s", right_bound="/",values=subjects_values, description='subjects') - ] - odh = OfflineDataHandler() - odh.get_data(folder_location=self.dataset_folder, regex_filters=regex_filters, delimiter=",") - return odh - -# given a directory, return a list of files in that directory matching a format -# can be nested -# this is just a handly utility -def find_all_files_of_type_recursively(dir, terminator): - files = os.listdir(dir) - file_list = [] - for file in files: - if file.endswith(terminator): - file_list.append(dir+file) - else: - if os.path.isdir(dir+file): - file_list += find_all_files_of_type_recursively(dir+file+'/',terminator) - return file_list - - -class OneSubjectMyoDataset(Dataset): - def __init__(self, save_dir='.', redownload=False, dataset_name="OneSubjectMyoDataset"): - Dataset.__init__(self, save_dir, redownload) - self.url = "https://github.com/libemg/OneSubjectMyoDataset" - self.dataset_name = dataset_name - self.dataset_folder = os.path.join(self.save_dir , self.dataset_name) - - if (not self.check_exists(self.dataset_folder)): - self.download(self.url, self.dataset_folder) - elif (self.redownload): - self.remove_dataset(self.dataset_folder) - self.download(self.url, self.dataset_folder) - - def prepare_data(self, format=OfflineDataHandler): - if format == OfflineDataHandler: - sets_values = ["1","2","3","4","5","6"] - classes_values = ["0","1","2","3","4"] - reps_values = ["0","1"] - regex_filters = [ - RegexFilter(left_bound = "/trial_", right_bound="/", values = sets_values, description='sets'), - RegexFilter(left_bound = "C_", right_bound=".csv", values = classes_values, description='classes'), - RegexFilter(left_bound = "R_", right_bound="_", values = reps_values, description='reps') - ] - odh = OfflineDataHandler() - odh.get_data(folder_location=self.dataset_folder, regex_filters=regex_filters, delimiter=",") - return odh - - -class _SessionFetcher(MetadataFetcher): - def __init__(self): - super().__init__('sessions') - - def __call__(self, filename, file_data, all_files): - def split_filename(f): - # Split date and name into separate variables - date_idx = f.find('2018') - date = datetime.strptime(Path(f[date_idx:]).stem, '%Y-%m-%d-%H-%M-%S-%f') - description = f[:date_idx] - return date, description - - data_file_date, data_file_description = split_filename(filename) - - # Grab the other file of a different date. Return the index of which session it is - same_subject_files = [f for f in all_files if data_file_description in f] - file_dates = [split_filename(subject_filename)[0] for subject_filename in same_subject_files] - file_dates.sort() - session_idx = file_dates.index(data_file_date) - return session_idx * np.ones((file_data.shape[0], 1), dtype=int) - - -class _RepFetcher(ColumnFetch): - def __call__(self, filename, file_data, all_files): - column_data = super().__call__(filename, file_data, all_files) + train_data = data['Train'] + test_data = data['Test'] - # Get rep transitions - diff = np.diff(column_data, axis=0) - rep_end_row_mask, rep_end_col_mask = np.nonzero((diff < 0) & (column_data[1:] == 0)) - unique_rep_end_row_mask = np.unique(rep_end_row_mask) # remove duplicate start indices (for combined movements) - # rest_end_row_mask = np.nonzero(np.diff(np.nonzero(column_data == 0)[0]) > 1)[0] - # rest_end_row_mask = np.nonzero(np.diff(np.nonzero(np.all(column_data == 0, axis=1))[0]) > 1)[0] - # unique_rep_end_row_mask = np.concatenate((unique_rep_end_row_mask, rest_end_row_mask)) - # unique_rep_end_row_mask = np.sort(unique_rep_end_row_mask) - - - # Populate metadata array - metadata = np.empty((column_data.shape[0], 1), dtype=np.int16) - rep_counters = [0 for _ in range(5)] # 5 different press types - previous_rep_start = 0 - for idx, rep_start in enumerate(unique_rep_end_row_mask): - movement_idx = 4 if np.sum(rep_end_row_mask == rep_start) > 1 else rep_end_col_mask[idx] # if multiple columns are nonzero then it's a combined movement - rep = rep_counters[movement_idx] - metadata[previous_rep_start:rep_start] = rep - rep_counters[movement_idx] += 1 - previous_rep_start = rep_start - - # Fill in final samples - metadata[rep_start:] = rep - - return metadata - - -class PutEMGForceDataset(Dataset): - def __init__(self, save_dir = '.', dataset_name = 'PutEMGForceDataset', data_filetype = None): - """Dataset wrapper for putEMG-Force dataset. Used for regression of finger forces. - - Parameters - ---------- - save_dir : str, default='.' - Base data directory. - dataset_name : str, default='PutEMGForceDataset' - Name of dataset. Looks for dataset in filepath created by appending save_dir and dataset_name. - data_filetype : list or None, default=None - Type of data file to use. Accepted values are 'repeats_long', 'repeats_short', 'sequential', or any combination of those. If None is passed, all will be used. - """ - # TODO: Implement downloading dataset using .sh or .py file - super().__init__(save_dir) - self.dataset_name = dataset_name - self.dataset_folder = os.path.join(self.save_dir, self.dataset_name) - if data_filetype is None: - data_filetype = ['repeats_short', 'repeats_long', 'sequential'] - elif not isinstance(data_filetype, list): - data_filetype = [data_filetype] - self.data_filetype = data_filetype - - def prepare_data(self, format=OfflineDataHandler, subjects = None, sessions = None, reps = None, labels = 'forces', label_dof_mask = None): - if subjects is None: - subjects = [str(idx).zfill(2) for idx in range(60)] - - if labels == 'forces': - column_mask = np.arange(25, 35) - elif labels == 'trajectories': - column_mask = np.arange(36, 40) - else: - raise ValueError(f"Expected either 'forces' or trajectories' for labels parameter, but received {labels}.") - - if label_dof_mask is not None: - column_mask = column_mask[label_dof_mask] - - if format == OfflineDataHandler: - regex_filters = [ - RegexFilter(left_bound='/emg_force-', right_bound='-', values=subjects, description='subjects'), - RegexFilter(left_bound='-', right_bound='-', values=self.data_filetype, description='data_filetype'), - ] - metadata_fetchers = [ - _SessionFetcher(), - ColumnFetch('labels', column_mask), - _RepFetcher('reps', list(range(36, 40))) - ] - odh = OfflineDataHandler() - odh.get_data(folder_location=self.dataset_folder, regex_filters=regex_filters, metadata_fetchers=metadata_fetchers, delimiter=',', skiprows=1, data_column=list(range(1, 25))) - if sessions is not None: - odh = odh.isolate_data('sessions', sessions) - if reps is not None: - odh = odh.isolate_data('reps', reps) - return odh - - -class OneSubjectEMaGerDataset(Dataset): - def __init__(self, save_dir = '.', redownload = False, dataset_name = 'OneSubjectEMaGerDataset'): - super().__init__(save_dir, redownload) - self.url = 'https://github.com/LibEMG/OneSubjectEMaGerDataset' - self.dataset_name = dataset_name - self.dataset_folder = os.path.join(self.save_dir, self.dataset_name) - - if (not self.check_exists(self.dataset_folder)): - self.download(self.url, self.dataset_folder) - elif (self.redownload): - self.remove_dataset(self.dataset_folder) - self.download(self.url, self.dataset_folder) - - def prepare_data(self, format=OfflineDataHandler): - if format == OfflineDataHandler: - regex_filters = [ - RegexFilter(left_bound='/', right_bound='/', values=['open-close', 'pro-sup'], description='movements'), - RegexFilter(left_bound='_R_', right_bound='_emg.csv', values=[str(idx) for idx in range(5)], description='reps') - ] - package_function = lambda x, y: Path(x).parent.absolute() == Path(y).parent.absolute() - metadata_fetchers = [FilePackager(RegexFilter(left_bound='/', right_bound='.txt', values=['labels'], description='labels'), package_function)] - odh = OfflineDataHandler() - odh.get_data(folder_location=self.dataset_folder, regex_filters=regex_filters, metadata_fetchers=metadata_fetchers) - return odh + accs = [] + for s in range(0, dataset.num_subjects): + print(str(s) + '/' + str(dataset.num_subjects) + ' completed.') + s_train_dh = train_data.isolate_data('subjects', [s]) + s_test_dh = test_data.isolate_data('subjects', [s]) + train_windows, train_meta = s_train_dh.parse_windows(int(dataset.sampling/1000 * window_size), int(dataset.sampling/1000 * window_inc), metadata_operations=metadata_operations) + test_windows, test_meta = s_test_dh.parse_windows(int(dataset.sampling/1000 * window_size), int(dataset.sampling/1000 * window_inc), metadata_operations=metadata_operations) -# class GRABMyo(Dataset): -# def __init__(self, save_dir='.', redownload=False, subjects=list(range(1,44)), sessions=list(range(1,4)), dataset_name="GRABMyo"): -# Dataset.__init__(self, save_dir, redownload) -# self.url = "https://physionet.org/files/grabmyo/1.0.2/" -# self.dataset_name = dataset_name -# self.dataset_folder = os.path.join(self.save_dir , self.dataset_name) -# self.subjects = subjects -# self.sessions = sessions - -# if (not self.check_exists(self.dataset_folder)): -# self.download_data() -# elif (self.redownload): -# self.remove_dataset(self.dataset_folder) -# self.download_data() -# else: -# print("Data Already Downloaded.") - -# def download_data(self): -# curl_command = "curl --create-dirs" + " -O --output-dir " + str(self.dataset_folder) + "/ " -# # Download files -# print("Starting download...") -# files = ['readme.txt', 'subject-info.csv', 'MotionSequence.txt'] -# for f in files: -# os.system(curl_command + self.url + f) -# for session in self.sessions: -# curl_command = "curl --create-dirs" + " -O --output-dir " + str(self.dataset_folder) + "/" + "Session" + str(session) + "/ " -# for p in self.subjects: -# for t in range(1,8): -# for g in range(1,18): -# endpoint = self.url + "Session" + str(session) + "/session" + str(session) + "_participant" + str(p) + "/session" + str(session) + "_participant" + str(p) + "_gesture" + str(g) + "_trial" + str(t) -# os.system(curl_command + endpoint + '.hea') -# os.system(curl_command + endpoint + '.dat') -# print("Download complete.") - -# def prepare_data(self, format=OfflineDataHandler, subjects=[str(i) for i in range(1,44)], sessions=["1","2","3"]): -# if format == OfflineDataHandler: -# sets_regex = make_regex(left_bound = "session", right_bound="_", values = sessions) -# classes_values = ["1","2","3","4","5","6","7","8","9","10","11","12","13","14","15","16","17"] -# classes_regex = make_regex(left_bound = "_gesture", right_bound="_", values = classes_values) -# reps_values = ["1","2","3","4","5","6","7"] -# reps_regex = make_regex(left_bound = "trial", right_bound=".hea", values = reps_values) -# subjects_regex = make_regex(left_bound="participant", right_bound="_",values=subjects) -# dic = { -# "sessions": sessions, -# "sessions_regex": sets_regex, -# "reps": reps_values, -# "reps_regex": reps_regex, -# "classes": classes_values, -# "classes_regex": classes_regex, -# "subjects": subjects, -# "subjects_regex": subjects_regex -# } -# odh = OfflineDataHandler() -# odh.get_data(folder_location=self.dataset_folder, filename_dic=dic, delimiter=",") -# return odh - -# def print_info(self): -# print('Reference: https://www.physionet.org/content/grabmyo/1.0.2/') -# print('Name: ' + self.dataset_name) -# print('Gestures: 17') -# print('Trials: 7') -# print('Time Per Rep: 5s') -# print('Subjects: 43') -# print("Forearm EMG (16): Columns 0-15\nWrist EMG (12): 18-23 and 26-31\nUnused (4): 16,23,24,31") + fe = FeatureExtractor() + train_feats = fe.extract_features(feature_list, train_windows, feature_dic=feature_dic) + test_feats = fe.extract_features(feature_list, test_windows, feature_dic=feature_dic) + ds = { + 'training_features': train_feats, + 'training_labels': train_meta[label_val] + } -# class NinaDB1(Dataset): -# def __init__(self, dataset_dir, subjects): -# Dataset.__init__(self, dataset_dir) -# self.dataset_folder = dataset_dir -# self.subjects = subjects - -# if (not self.check_exists(self.dataset_folder)): -# print("The dataset does not currently exist... Please download it from: http://ninaweb.hevs.ch/data1") -# exit(1) -# else: -# filenames = next(walk(self.dataset_folder), (None, None, []))[2] -# if not any("csv" in f for f in filenames): -# self.setup(filenames) -# print("Extracted and set up repo.") -# self.prepare_data() - -# def setup(self, filenames): -# for f in filenames: -# if "zip" in f: -# file_path = os.path.join(self.dataset_folder, f) -# with zipfile.ZipFile(file_path, 'r') as zip_ref: -# zip_ref.extractall(self.dataset_folder) -# self.convert_data() - -# def convert_data(self): -# mat_files = [y for x in os.walk(self.dataset_folder) for y in glob(os.path.join(x[0], '*.mat'))] -# for f in mat_files: -# mat_dict = sio.loadmat(f) -# output_ = np.concatenate((mat_dict['emg'], mat_dict['restimulus'], mat_dict['rerepetition']), axis=1) -# mask_ids = output_[:,11] != 0 -# output_ = output_[mask_ids,:] -# np.savetxt(f[:-4]+'.csv', output_,delimiter=',') - -# def cleanup_data(self): -# mat_files = [y for x in os.walk(self.dataset_folder) for y in glob(os.path.join(x[0], '*.mat'))] -# zip_files = [y for x in os.walk(self.dataset_folder) for y in glob(os.path.join(x[0], '*.zip'))] -# files = mat_files + zip_files -# for f in files: -# os.remove(f) - -# def prepare_data(self, format=OfflineDataHandler): -# if format == OfflineDataHandler: -# classes_values = list(range(1,24)) -# classes_column = [10] -# classset_values = [str(i) for i in list(range(1,4))] -# classset_regex = make_regex(left_bound="_E", right_bound=".csv", values=classset_values) -# reps_values = list(range(1,11)) - -# reps_column = [11] -# subjects_values = [str(s) for s in self.subjects] -# subjects_regex = make_regex(left_bound="S", right_bound="_A", values=subjects_values) -# data_column = list(range(0,10)) -# dic = { -# "reps": reps_values, -# "reps_column": reps_column, -# "classes": classes_values, -# "classes_column": classes_column, -# "subjects": subjects_values, -# "subjects_regex": subjects_regex, -# "classset": classset_values, -# "classset_regex": classset_regex, -# "data_column": data_column -# } -# odh = OfflineDataHandler() -# odh.get_data(folder_location=self.dataset_folder, filename_dic=dic, delimiter=",") -# return odh + if not regression: + clf = EMGClassifier(model) + else: + clf = EMGRegressor(model) + clf.fit(ds) + + if regression: + preds = clf.run(test_feats) + else: + preds, _ = clf.run(test_feats) + + metrics = om.extract_offline_metrics(metrics, test_meta[label_val], preds) + accs.append(metrics) + + print(metrics) + accuracies[d] = accs + + with open(output_file, 'wb') as handle: + pickle.dump(accuracies, handle, protocol=pickle.HIGHEST_PROTOCOL) + + return accuracies \ No newline at end of file diff --git a/libemg/emg_predictor.py b/libemg/emg_predictor.py index 2528861c..4e2ce291 100644 --- a/libemg/emg_predictor.py +++ b/libemg/emg_predictor.py @@ -7,6 +7,7 @@ from sklearn.naive_bayes import GaussianNB from sklearn.neural_network import MLPClassifier, MLPRegressor from sklearn.svm import SVC, SVR +from sklearn.preprocessing import StandardScaler from libemg.feature_extractor import FeatureExtractor from libemg.shared_memory_manager import SharedMemoryManager from multiprocessing import Process, Lock @@ -585,6 +586,10 @@ class OnlineStreamer(ABC): If True, prints predictions to std_out. tcp: bool (optional), default = False If True, will stream predictions over TCP instead of UDP. + feature_queue_length: int (optional), default = 0 + Number of windows to include in online feature queue. Used for time series models that make a prediction on a sequence of feature windows + (batch x feature_queue_length x features) instead of raw EMG. If the value is greater than 0, creates a queue and passes the data to the model as + a 1 x feature_queue_length x num_features. If the value is 0, no feature queue is created and predictions are made on a single window (1 x features). Defaults to 0. """ def __init__(self, @@ -597,7 +602,9 @@ def __init__(self, features, port, ip, std_out, - tcp): + tcp, + feature_queue_length): + self.window_size = window_size self.window_increment = window_increment self.odh = online_data_handler @@ -605,7 +612,9 @@ def __init__(self, self.port = port self.ip = ip self.predictor = offline_predictor - + self.feature_queue_length = feature_queue_length + self.queue = deque(maxlen=feature_queue_length) if self.feature_queue_length > 0 else None + self.scaler = None self.options = {'file': file, 'file_path': file_path, 'std_out': std_out} @@ -637,49 +646,6 @@ def start_stream(self, block=True): self._run_helper() else: self.process.start() - - def write_output(self, prediction, probabilities, probability, calculated_velocity, model_input): - time_stamp = time.time() - if calculated_velocity == "": - printed_velocity = "-1" - else: - printed_velocity = float(calculated_velocity) - if self.options['std_out']: - print(f"{int(prediction)} {printed_velocity} {time.time()}") - # Write classifier output: - if self.options['file']: - if not 'file_handle' in self.files.keys(): - self.files['file_handle'] = open(self.options['file_path'] + 'classifier_output.txt', "a", newline="") - writer = csv.writer(self.files['file_handle']) - feat_str = str(model_input[0]).replace('\n','')[1:-1] - row = [f"{time_stamp} {prediction} {probability[0]} {printed_velocity} {feat_str}"] - writer.writerow(row) - self.files['file_handle'].flush() - if "smm" in self.options.keys(): - # assumed to have "classifier_input" and "classifier_output" keys - # these are (1+) - def insert_classifier_input(data): - input_size = self.options['smm'].variables['classifier_input']["shape"][0] - data[:] = np.vstack((np.hstack([time_stamp, model_input[0]]), data))[:input_size,:] - return data - def insert_classifier_output(data): - output_size = self.options['smm'].variables['classifier_output']["shape"][0] - data[:] = np.vstack((np.hstack([time_stamp, prediction, probability[0], float(printed_velocity)]), data))[:output_size,:] - return data - self.options['smm'].modify_variable("classifier_input", - insert_classifier_input) - self.options['smm'].modify_variable("classifier_output", - insert_classifier_output) - self.options['classifier_smm_writes'] += 1 - - if self.output_format == "predictions": - message = str(prediction) + calculated_velocity + '\n' - elif self.output_format == "probabilities": - message = ' '.join([f'{i:.2f}' for i in probabilities[0]]) + calculated_velocity + " " + str(time_stamp) - if not self.tcp: - self.sock.sendto(bytes(message, 'utf-8'), (self.ip, self.port)) - else: - self.conn.sendall(str.encode(message)) def prepare_smm(self): for i in self.smm_items: @@ -755,14 +721,13 @@ def _run_helper(self): self.odh.prepare_smm() - if self.features is not None: - fe = FeatureExtractor() self.expected_count = {mod:self.window_size for mod in self.odh.modalities} # todo: deal with different sampling frequencies for different modalities self.odh.reset() files = {} + fe = FeatureExtractor() while True: if self.smm: if not self.options["smm"].get_variable("active_flag")[0,0]: @@ -782,16 +747,29 @@ def _run_helper(self): window = {mod:get_windows(data[mod], self.window_size, self.window_increment) for mod in self.odh.modalities} # Dealing with the case for CNNs when no features are used - if self.features: + if self.features is not None: model_input = None for mod in self.odh.modalities: # todo: features for each modality can be different - mod_features = fe.extract_features(self.features, window[mod], self.predictor.feature_params) - mod_features = self._format_data_sample(mod_features) + mod_features = fe.extract_features(self.features, window[mod], array=True) if model_input is None: model_input = mod_features else: model_input = np.hstack((model_input, mod_features)) + + if self.scaler is not None: + model_input = self.scaler.transform(model_input) + + if self.queue is not None: + # Queue features from previous windows + self.queue.append(model_input) # oldest windows will automatically be dequeued if length exceeds maxlen + + if len(self.queue) < self.feature_queue_length: + # Skip until buffer fills up + continue + + model_input = np.concatenate(self.queue, axis=0) + model_input = np.expand_dims(model_input, axis=0) # cast to 3D here for time series models else: model_input = window[list(window.keys())[0]] #TODO: Change this @@ -801,6 +779,23 @@ def _run_helper(self): self.write_output(model_input, window) + def install_standardization(self, standardization: np.ndarray | StandardScaler): + """Install standardization to online model. Standardizes each feature based on training data (i.e., standardizes across windows). + Standardization is only applied when features are extracted and is applied before feature queueing (i.e., features are standardized then queued) + if relevant. + To standardize data, use the standardize Filter. + + :param standardization: Standardization data. If an array, creates a scaler and fits to the provided array. If a StandardScaler, uses the StandardScaler. + :type standardization: np.ndarray | StandardScaler + """ + scaler = standardization + + if not isinstance(scaler, StandardScaler): + # Fit scaler to provided data + scaler = StandardScaler().fit(np.array(standardization)) + + self.scaler = scaler + # ----- All of these are unique to each online streamer ---------- def run(self): pass @@ -856,12 +851,16 @@ class OnlineEMGClassifier(OnlineStreamer): If True, will stream predictions over TCP instead of UDP. output_format: str (optional), default=predictions If predictions, it will broadcast an integer of the prediction, if probabilities it broacasts the posterior probabilities + feature_queue_length: int (optional), default = 0 + Number of windows to include in online feature queue. Used for time series models that make a prediction on a sequence of feature windows + (batch x feature_queue_length x features) instead of raw EMG. If the value is greater than 0, creates a queue and passes the data to the model as + a 1 x feature_queue_length x num_features. If the value is 0, no feature queue is created and predictions are made on a single window (1 x features). Defaults to 0. """ def __init__(self, offline_classifier, window_size, window_increment, online_data_handler, features, file_path = '.', file=False, smm=False, smm_items= None, port=12346, ip='127.0.0.1', std_out=False, tcp=False, - output_format="predictions"): + output_format="predictions", feature_queue_length = 0): if smm_items is None: smm_items = [ @@ -871,7 +870,7 @@ def __init__(self, offline_classifier, window_size, window_increment, online_dat assert 'classifier_input' in [item[0] for item in smm_items], f"'model_input' tag not found in smm_items. Got: {smm_items}." assert 'classifier_output' in [item[0] for item in smm_items], f"'model_output' tag not found in smm_items. Got: {smm_items}." super(OnlineEMGClassifier, self).__init__(offline_classifier, window_size, window_increment, online_data_handler, - file_path, file, smm, smm_items, features, port, ip, std_out, tcp) + file_path, file, smm, smm_items, features, port, ip, std_out, tcp, feature_queue_length) self.output_format = output_format self.previous_predictions = deque(maxlen=self.predictor.majority_vote) self.smi = smm_items @@ -1075,10 +1074,14 @@ class OnlineEMGRegressor(OnlineStreamer): If True, prints predictions to std_out. tcp: bool (optional), default = False If True, will stream predictions over TCP instead of UDP. + feature_queue_length: int (optional), default = 0 + Number of windows to include in online feature queue. Used for time series models that make a prediction on a sequence of windows instead of raw EMG. + If the value is greater than 0, creates a queue and passes the data to the model as a 1 (window) x feature_queue_length x num_features. + If the value is 0, no feature queue is created and predictions are made on a single window. Defaults to 0. """ def __init__(self, offline_regressor, window_size, window_increment, online_data_handler, features, file_path = '.', file = False, smm = False, smm_items = None, - port=12346, ip='127.0.0.1', std_out=False, tcp=False): + port = 12346, ip = '127.0.0.1', std_out = False, tcp = False, feature_queue_length = 0): if smm_items is None: # I think probably just have smm_items default to None and remove the smm flag. Then if the user wants to track stuff, they can pass in smm_items and a function to handle them? smm_items = [ @@ -1088,7 +1091,7 @@ def __init__(self, offline_regressor, window_size, window_increment, online_data assert 'model_input' in [item[0] for item in smm_items], f"'model_input' tag not found in smm_items. Got: {smm_items}." assert 'model_output' in [item[0] for item in smm_items], f"'model_output' tag not found in smm_items. Got: {smm_items}." super(OnlineEMGRegressor, self).__init__(offline_regressor, window_size, window_increment, online_data_handler, file_path, - file, smm, smm_items, features, port, ip, std_out, tcp) + file, smm, smm_items, features, port, ip, std_out, tcp, feature_queue_length) self.smi = smm_items def run(self, block=True): diff --git a/libemg/feature_extractor.py b/libemg/feature_extractor.py index 1e513d8e..f174a292 100644 --- a/libemg/feature_extractor.py +++ b/libemg/feature_extractor.py @@ -1397,15 +1397,18 @@ def getWENGfeat(self, windows, WENG_fs = 1000): list The computed features associated with each window. Size: Wx((order+1)*Nchannels) """ - # get the highest power of 2 the nyquist rate is divisible by - order = math.floor(np.log(WENG_fs/2)/np.log(2) - 1) - # Khushaba et al suggests using sym8 - # note, this will often throw a WARNING saying the user specified order is too high -- but this is what the - # original paper suggests using as the order. - wavelets = wavedec(windows, wavelet='sym8', level=order,axis=2) - # for every order, compute the energy (sum of DWT) - total of the squared signal - features = np.hstack([np.log(np.sum(i**2, axis=2)+1e-10) for i in wavelets]) - return features + import warnings + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # get the highest power of 2 the nyquist rate is divisible by + order = math.floor(np.log(WENG_fs/2)/np.log(2) - 1) + # Khushaba et al suggests using sym8 + # note, this will often throw a WARNING saying the user specified order is too high -- but this is what the + # original paper suggests using as the order. + wavelets = wavedec(windows, wavelet='sym8', level=order,axis=2) + # for every order, compute the energy (sum of DWT) - total of the squared signal + features = np.hstack([np.log(np.sum(i**2, axis=2)+1e-10) for i in wavelets]) + return features def getWVfeat(self, windows, WV_fs=1000): diff --git a/libemg/streamers.py b/libemg/streamers.py index 2543d065..18b44be0 100644 --- a/libemg/streamers.py +++ b/libemg/streamers.py @@ -15,13 +15,13 @@ from libemg._streamers._sifi_bridge_streamer import SiFiBridgeStreamer from libemg._streamers._leap_streamer import LeapStreamer -def sifibridge_streamer(version="1_1", +def sifi_biopoint_streamer(name="BioPoint_v1_3", shared_memory_items = None, - ecg=False, + ecg=True, emg=True, - eda=False, - imu=False, - ppg=False, + eda=True, + imu=True, + ppg=True, notch_on=True, notch_freq=60, emg_fir_on = True, emg_fir=[20,450], @@ -30,26 +30,140 @@ def sifibridge_streamer(version="1_1", fc_hp = 5, # high pass eda freq = 250,# eda sampling frequency streaming=False, - mac= None): + mac= None, + bridge_version = "0.6.4"): # TODO, replace bridge_version with none after Sifi updates + """The streamer for the sifi biopoint. + This function connects to the sifi bridge and streams its data to the SharedMemory. This is used + for the SiFi biopoint. + Note that the IMU is acc_x, acc_y, acc_z, quat_w, quat_x, quat_y, quat_z. + Parameters + ---------- + name: string (option), default = 'BioPoint_v1_3' + The name for the sifi device. + shared_memory_items, default = [] + The key, size, datatype, and multiprocessing Lock for all data to be shared between processes. + ecg, default = True + The flag to enable electrocardiography recording from the main sensor unit. + emg, default = True + The flag to enable electromyography recording. + eda, default = True + The flag to enable electrodermal recording. + imu, default = True + The flag to enable inertial measurement unit recording + ppg, default = True + The flag to enable photoplethysmography recording + notch_on, default = True + The flag to enable a fc Hz notch filter on device (firmware). + notch_freq, default = 60 + The cutoff frequency of the notch filter specified by notch_on. + emg_fir_on, default = True + The flag to enable a bandpass filter on device (firmware). + emg_fir, default = [20, 450] + The low and high cutoff frequency of the bandpass filter specified by emg_fir_on. + eda_cfg, default = True + The flag to specify if using high or low frequency current for EDA or bioimpedance. + fc_lp, default = 0 + The low cutoff frequency for the bioimpedance. + fc_hp, default = 5 + The high cutoff frequency for the bioimpedance. + freq, default = 250 + The sampling frequency for bioimpedance. + streaming, default = False + Whether to package the modalities together within packets for lower latency. + mac, default = None: + mac address of the device to be connected to + Returns + ---------- + Object: streamer + The sifi streamer process object. + Object: shared memory + The shared memory items list to be passed to the OnlineDataHandler. + + Examples + --------- + >>> streamer, shared_memory = sifibridge_streamer() + """ + + if shared_memory_items is None: + shared_memory_items = [] + if emg: + shared_memory_items.append(["emg", (4000,1), np.double]) + shared_memory_items.append(["emg_count", (1,1), np.int32]) + if imu: + shared_memory_items.append(["imu", (200,7), np.double]) + shared_memory_items.append(["imu_count", (1,1), np.int32]) + if ecg: + shared_memory_items.append(["ecg", (1000,1), np.double]) + shared_memory_items.append(["ecg_count", (1,1), np.int32]) + if eda: + shared_memory_items.append(["eda", (200,1), np.double]) + shared_memory_items.append(["eda_count", (1,1), np.int32]) + if ppg: + shared_memory_items.append(["ppg", (200,4), np.double]) + shared_memory_items.append(["ppg_count", (1,1), np.int32]) + + for item in shared_memory_items: + item.append(Lock()) + sb = SiFiBridgeStreamer(name=name, + shared_memory_items=shared_memory_items, + notch_on=notch_on, + ecg=ecg, + emg=emg, + eda=eda, + imu=imu, + ppg=ppg, + notch_freq=notch_freq, + emgfir_on=emg_fir_on, + emg_fir = emg_fir, + eda_cfg = eda_cfg, + fc_lp = fc_lp, # low pass eda + fc_hp = fc_hp, # high pass eda + freq = freq,# eda sampling frequency + streaming=streaming, + mac = mac, + bridge_version=bridge_version) + sb.start() + return sb, shared_memory_items + + + +def sifi_bioarmband_streamer(name="BioPoint_v1_1", + shared_memory_items = None, + ecg=True, + emg=True, + eda=True, + imu=True, + ppg=True, + notch_on=False,#I'm pretty sure these aren't configured right for 1500Hz + notch_freq=60, + emg_fir_on = False,#I'm pretty sure these aren't configured right for 1500Hz + emg_fir=[20,450], + eda_cfg = True, + fc_lp = 0, # low pass eda + fc_hp = 5, # high pass eda + freq = 250,# eda sampling frequency + streaming=False, + mac= None, + bridge_version = "0.6.4"):# TODO, replace bridge_version with none after Sifi updates """The streamer for the sifi armband. This function connects to the sifi bridge and streams its data to the SharedMemory. This is used - for the SiFi biopoint and bioarmband. + for the SiFi bioarmband. Note that the IMU is acc_x, acc_y, acc_z, quat_w, quat_x, quat_y, quat_z. Parameters ---------- - version: string (option), default = '1_1' - The version for the sifi streamer. + name: string (option), default = 'BioPoint_v1_1' + The name for the sifi device. shared_memory_items, default = [] The key, size, datatype, and multiprocessing Lock for all data to be shared between processes. - ecg, default = False + ecg, default = True The flag to enable electrocardiography recording from the main sensor unit. emg, default = True The flag to enable electromyography recording. - eda, default = False + eda, default = True The flag to enable electrodermal recording. - imu, default = False + imu, default = True The flag to enable inertial measurement unit recording - ppg, default = False + ppg, default = True The flag to enable photoplethysmography recording notch_on, default = True The flag to enable a fc Hz notch filter on device (firmware). @@ -89,21 +203,21 @@ def sifibridge_streamer(version="1_1", shared_memory_items.append(["emg", (3000,8), np.double]) shared_memory_items.append(["emg_count", (1,1), np.int32]) if imu: - shared_memory_items.append(["imu", (100,10), np.double]) + shared_memory_items.append(["imu", (200,7), np.double]) shared_memory_items.append(["imu_count", (1,1), np.int32]) if ecg: - shared_memory_items.append(["ecg", (100,10), np.double]) + shared_memory_items.append(["ecg", (1000,1), np.double]) shared_memory_items.append(["ecg_count", (1,1), np.int32]) if eda: - shared_memory_items.append(["eda", (100,10), np.double]) + shared_memory_items.append(["eda", (200,1), np.double]) shared_memory_items.append(["eda_count", (1,1), np.int32]) if ppg: - shared_memory_items.append(["ppg", (100,10), np.double]) + shared_memory_items.append(["ppg", (200,4), np.double]) shared_memory_items.append(["ppg_count", (1,1), np.int32]) for item in shared_memory_items: item.append(Lock()) - sb = SiFiBridgeStreamer(version=version, + sb = SiFiBridgeStreamer(name=name, shared_memory_items=shared_memory_items, notch_on=notch_on, ecg=ecg, @@ -119,10 +233,14 @@ def sifibridge_streamer(version="1_1", fc_hp = fc_hp, # high pass eda freq = freq,# eda sampling frequency streaming=streaming, - mac = mac) + mac = mac, + bridge_version=bridge_version) sb.start() return sb, shared_memory_items + + + def myo_streamer( shared_memory_items : list | None = None, emg : bool = True, @@ -263,7 +381,7 @@ def delsys_api_streamer(license : str = None, Returns ---------- Object: streamer - The sifi streamer object. + The delsys streamer object. Object: shared memory The shared memory object. Examples @@ -308,7 +426,7 @@ def oymotion_streamer(shared_memory_items : list | None = None, Returns ---------- Object: streamer - The sifi streamer object + The oymotion streamer object Object: shared memory The shared memory object Examples @@ -337,7 +455,7 @@ def oymotion_streamer(shared_memory_items : list | None = None, operating_system = platform.system().lower() # I'm only addressing this atm. - if operating_system == "windows" or operating_system == 'mac': + if operating_system == "windows" or operating_system == 'darwin': oym = Gforce(sampling_rate, res, emg, imu, shared_memory_items) oym.start() else: @@ -362,7 +480,7 @@ def emager_streamer(shared_memory_items = None): Returns ---------- Object: streamer - The sifi streamer object. + The emager streamer object. Object: shared memory The shared memory object. Examples diff --git a/libemg/utils.py b/libemg/utils.py index 765beb29..8422dd5a 100644 --- a/libemg/utils.py +++ b/libemg/utils.py @@ -62,7 +62,7 @@ def _get_fn_windows(data, window_size, window_increment, fn): fn_of_windows = np.apply_along_axis(lambda x: fn(x), axis=2, arr=windows) return fn_of_windows.squeeze() -def make_regex(left_bound, right_bound, values=[]): +def make_regex(left_bound, right_bound, values = None): """Regex creation helper for the data handler. The OfflineDataHandler relies on regexes to parse the file/folder structures and extract data. @@ -74,8 +74,8 @@ def make_regex(left_bound, right_bound, values=[]): The left bound of the regex. right_bound: string The right bound of the regex. - values: list - The values between the two regexes. + values: list or None (optional), default = None + The values between the two regexes. If None, will try to find the values using a wildcard. Defaults to None. Returns ---------- @@ -87,10 +87,16 @@ def make_regex(left_bound, right_bound, values=[]): >>> make_regex(left_bound = "_C_", right_bound="_EMG.csv", values = [0,1,2,3,4,5]) """ left_bound_str = "(?<="+ left_bound +")" - mid_str = "(?:" - for i in values: - mid_str += i + "|" - mid_str = mid_str[:-1] - mid_str += ")" + + if values is None: + # Apply wildcard + mid_str = '(.*?)' + else: + mid_str = "(?:" + for i in values: + mid_str += i + "|" + mid_str = mid_str[:-1] + mid_str += ")" + right_bound_str = "(?=" + right_bound +")" return left_bound_str + mid_str + right_bound_str diff --git a/requirements.txt b/requirements.txt index 619df973..3ded04a3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,6 +12,7 @@ wfdb bleak semantic-version requests +h5py # For Docs sphinx==5.0.0 sphinx_rtd_theme==1.0.0 @@ -28,3 +29,5 @@ dearpygui opencv-python datetime websockets==8.1 +h5py +onedrivedownloader \ No newline at end of file diff --git a/setup.py b/setup.py index 436f1b7d..42b6768a 100644 --- a/setup.py +++ b/setup.py @@ -40,7 +40,9 @@ "opencv-python", "pythonnet", "bleak", - "dearpygui" + "dearpygui", + "h5py", + "onedrivedownloader" ], keywords=[ "emg",