Classify Arm Motions Using EMG Signals and Deep Learning
This example shows how to classify forearm motions based on electromyographic (EMG) signals. An EMG signal measures the electrical activity of a muscle when it contracts.
Thirty subjects each participated in four data collection sessions, during which they performed six individual trials of different forearm motions while EMG signals were recorded from eight muscles. The data set consists of 1440 MAT-files: 720 of these files contain signal data and the other 720 files contain the corresponding label data. The label data consists of a motion variable, motion, and an index variable, data_indx. The data set is organized into subject folders that contain a subfolder for each session. Each session subfolder contains six signal data files and six label data files, one of each per trial.
The motion variable is a numeric array that represents seven different motions:
Hand Open
Hand Close
Wrist Flexion
Wrist Extension
Supination
Pronation
Rest
Each motion was held for three seconds and repeated four times in random order. The first and last elements in motion are equal to -1 and correspond to an extended rest period that was performed at the start and end of each trial. The data_indx variable contains the start index of each motion.
You can download the files at this location: https://ssd.mathworks.com/supportfiles/SPT/data/MyoelectricData.zip.
Create Datastores to Read Signal and Label Data
To access the files, create a signal datastore that points to the location where the files are downloaded. The files containing signal data have names that end with "d" and the files containing label data have names that end with "i". The sample rate is 3000 Hz. Create a subset of the datastore containing only signal data.
fs = 3000;
localfile = matlab.internal.examples.downloadSupportFile("SPT","data/MyoelectricData.zip");
datasetFolder = fullfile(fileparts(localfile),"MyoelectricData");
if ~exist(datasetFolder,"dir")
    unzip(localfile,datasetFolder)
end
sds1 = signalDatastore(datasetFolder,IncludeSubfolders=true,SampleRate=fs);
p = endsWith(sds1.Files,"d.mat");
sdssig = subset(sds1,p);
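The file-name filter can be tried in isolation before touching the real data set. This is a minimal sketch: the file names below are hypothetical and only follow the "d"/"i" suffix pattern described in this section; the actual names in the archive may differ.

```matlab
% Hypothetical file names (assumption: only the "d"/"i" suffixes match the data set)
files = ["S1_session1_trial1d.mat"; "S1_session1_trial1i.mat"; ...
         "S1_session1_trial2d.mat"; "S1_session1_trial2i.mat"];

isSignal = endsWith(files,"d.mat");   % logical mask selecting signal files
isLabel  = endsWith(files,"i.mat");   % logical mask selecting label files

signalFiles = files(isSignal)
labelFiles  = files(isLabel)
```

Each file matches exactly one of the two masks, so the two subsets split the file list without overlap.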
Create a second datastore that points to the same file location and specify the names of the two variables in the label files. Create a subset of this datastore containing only label data.
sds2 = signalDatastore(datasetFolder,SignalVariableNames=["motion";"data_indx"],IncludeSubfolders=true);
p = endsWith(sds2.Files,"i.mat");
sdslbl = subset(sds2,p);
Plot all eight channels of the first EMG signal to visualize the activation of each muscle during one trial.
signal = preview(sdssig);
for i = 1:8
    ax(i) = subplot(4,2,i);
    plot(signal(:,i))
    title("Channel"+i)
end
linkaxes(ax,"y")
Create ROI Table
Define region-of-interest (ROI) limits for each motion based on the indices in data_indx. Remove the first and last label values (equal to -1) and convert the remaining numeric labels into a categorical array. Create a table containing the ROI limits in the first column and the labels in the second column.
lbls = {};
i = 1;
while hasdata(sdslbl)
    label = read(sdslbl);

    % Start indices of the motions, excluding the extended rest periods
    idx_start = label{2}(2:end-1)';
    % Each ROI ends one sample before the next starts; the last ROI lasts 3 seconds
    idx_end = [idx_start(2:end)-1;idx_start(end)+(3*fs)];

    val = categorical(label{1}(2:end-1)',[1 2 3 4 5 6 7], ...
        ["HandOpen" "HandClose" "WristFlexion" "WristExtension" "Supination" "Pronation" "Rest"]);
    ROI = [idx_start idx_end];

    % In some cases, the number of label values and ROIs are not equal.
    % To eliminate these inconsistencies, remove the extra label value or ROI limits.
    if numel(val) < size(ROI,1)
        ROI(end,:) = [];
    elseif numel(val) > size(ROI,1)
        val(end) = [];
    end
    lbltable = table(ROI,val);
    lbls{i} = {lbltable};
    i = i+1;
end
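To see how the ROI limits follow from the start indices, the same indexing can be applied to a small synthetic label set. This is only a sketch: the start indices and label values below are made up for illustration and are not taken from the data set.

```matlab
fs = 3000;                                   % sample rate in Hz, as in the example

% Hypothetical label data: first and last entries mark the extended rest periods
motion    = [-1; 1; 3; 7; -1];
data_indx = [1; 9001; 18001; 27001; 36001];

idx_start = data_indx(2:end-1);              % drop the extended rest entries
idx_end   = [idx_start(2:end)-1; idx_start(end)+3*fs];
ROI = [idx_start idx_end];

val = categorical(motion(2:end-1),1:7, ...
    ["HandOpen" "HandClose" "WristFlexion" "WristExtension" "Supination" "Pronation" "Rest"]);

lbltable = table(ROI,val)
```

Every ROI ends one sample before the next one starts, and the last ROI is assigned a length of three seconds because no later start index is available for it.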
Prepare Datastore
Create a new datastore containing the modified label data and display the ROI table from the first observation.
lblDS = signalDatastore(lbls);
lblstable = preview(lblDS);
lblstable{1}
Combine the signal and label data into one datastore.
DS = combine(sdssig,lblDS);
combinedData = preview(DS)
Create a signal mask and call plotsigroi to display the labeled motion regions for the first channel of the first signal. The start and end of the signal shown in black represent the extended rest periods that are removed in the next preprocessing step.
figure
msk = signalMask(combinedData{2});
plotsigroi(msk,combinedData{1}(:,1))
Preprocess Data
Transform the combined datastore using the preprocess function, which performs these preprocessing tasks:
Remove the extended rest periods from the start and end of each signal.
Remove the pronation and supination motions. Because EMG was not recorded from the main muscles that enable forearm pronation, and was recorded for only one of the muscles involved in supination, the function excludes these motions from the data.
Remove the rest periods.
Filter the signal using a bandpass filter with a lower cutoff frequency of 10 Hz and an upper cutoff frequency of 400 Hz.
Downsample the signal and label data to 1000 Hz.
Create a signal mask for the regions of interest (motions) and labels, so that each signal sample has a corresponding label. This enables sequence-to-sequence classification.
Break the signals into shorter segments that are 12000 samples in length.
tDS = transform(DS,@preprocess);
transformedData = preview(tDS)
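The downsampling and segmentation numbers in the list above can be checked with a few lines of arithmetic. This is a sketch only: the signal length numSamples and the toy vector x are hypothetical values chosen for illustration.

```matlab
fsOriginal = 3000;                    % original sample rate in Hz
factor = 3;                           % downsampling factor used in preprocess
fsNew = fsOriginal/factor;            % new sample rate: 1000 Hz
nyquist = fsNew/2;                    % 500 Hz, above the 400 Hz upper cutoff

targetLength = 12000;                 % samples per segment
segmentSeconds = targetLength/fsNew;  % duration of one segment in seconds

numSamples = 50500;                   % hypothetical length after downsampling
numChunks = floor(numSamples/targetLength);
numDiscarded = numSamples - numChunks*targetLength;

% Toy version of the segmentation: truncate, then put one segment in each column
x = (1:10)';
len = 4;
n = floor(numel(x)/len);
segments = reshape(x(1:n*len),len,n)
```

Samples that do not fill a complete segment are discarded, so every segment passed to the network has the same length.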
Divide Data into Training and Testing Sets
Use 80% of the data to train the network and 20% to test the network. Multiply the random indices by 24 (6 trials x 4 sessions = 24 files for each subject) to avoid including data from a single subject in both training and testing sets.
rng default
[trainIdx,~,testIdx] = dividerand(30,0.8,0,0.2);

% Expand each training subject index into the indices of that subject's 24 files
trainIdx_all = {};
m = 1;
for k = trainIdx
    if k == 1
        start = k;
    else
        start = ((k-1)*24)+1;
    end
    l = start:k*24;
    trainIdx_all{m} = l;
    m = m+1;
end
trainIdx_all = cell2mat(trainIdx_all)';
trainDS = subset(tDS,trainIdx_all);

% Expand each testing subject index into the indices of that subject's 24 files
testIdx_all = {};
m = 1;
for k = testIdx
    if k == 1
        start = k;
    else
        start = ((k-1)*24)+1;
    end
    l = start:k*24;
    testIdx_all{m} = l;
    m = m+1;
end
testIdx_all = cell2mat(testIdx_all)';
testDS = subset(tDS,testIdx_all);
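The mapping from a subject index to that subject's file indices can be sketched on its own. The subject indices below are hypothetical, and the sketch assumes, as the split above does, that the datastore lists each subject's 24 files contiguously and in subject order.

```matlab
filesPerSubject = 24;                 % 6 trials x 4 sessions
subjects = [2 5];                     % hypothetical subject indices

fileIdx = cell(1,numel(subjects));
for m = 1:numel(subjects)
    k = subjects(m);
    % Subject k owns files (k-1)*24+1 through k*24
    fileIdx{m} = (k-1)*filesPerSubject+1 : k*filesPerSubject;
end
fileIdx = cell2mat(fileIdx)';

% Recover the subject of each file to confirm the mapping
recoveredSubjects = unique(ceil(fileIdx/filesPerSubject))'
```

Because the split is made on subject indices and each file index maps back to exactly one subject, no subject contributes files to both the training and testing sets.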
Train Network
To avoid repeating the preprocessing steps at every training epoch and thus reduce training time, read all the data into memory before training the network. You can speed up this process by reading the data in parallel (requires Parallel Computing Toolbox™).
traindata = readall(trainDS,"UseParallel",true);
Define a convolutional neural network. Specify a fullyConnectedLayer with an output size of 4, one output for each type of motion.
layers = [ ...
    sequenceInputLayer(8)
    convolution1dLayer(8,32,Stride=2,Padding="same")
    reluLayer
    layerNormalizationLayer
    convolution1dLayer(8,16,Stride=2,Padding="same")
    reluLayer
    layerNormalizationLayer
    transposedConv1dLayer(8,16,Stride=2,Cropping="same")
    reluLayer
    layerNormalizationLayer
    transposedConv1dLayer(8,32,Stride=2,Cropping="same")
    reluLayer
    layerNormalizationLayer
    fullyConnectedLayer(4)
    softmaxLayer
    ];
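For sequence-to-sequence classification, the network output needs one prediction per input sample, so the two strided convolutions must be undone by the two transposed convolutions. The sketch below traces the sequence length through the layers. It assumes the documented behavior that "same" padding gives an output length of ceil(inputLength/stride) and that "same" cropping gives an output length of inputLength*stride.

```matlab
seqLength = 12000;                        % segment length produced by preprocess
stride = 2;

afterConv1  = ceil(seqLength/stride);     % first convolution1dLayer
afterConv2  = ceil(afterConv1/stride);    % second convolution1dLayer
afterTConv1 = afterConv2*stride;          % first transposedConv1dLayer
afterTConv2 = afterTConv1*stride;         % second transposedConv1dLayer

lengths = [seqLength afterConv1 afterConv2 afterTConv1 afterTConv2]
```

Under these assumptions the output length matches the 12000-sample input, so every sample in a segment receives a class score.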
Specify options for network training. Use the Adam optimizer and a mini-batch size of 32. Set the initial learning rate to 0.001 and the maximum number of epochs to 100. Shuffle the data every epoch.
options = trainingOptions("adam", ...
    MaxEpochs=100, ...
    MiniBatchSize=32, ...
    Plots="training-progress", ...
    InitialLearnRate=0.001, ...
    Verbose=0, ...
    Shuffle="every-epoch", ...
    Metric="accuracy");
rawNet = trainnet(traindata(:,1),traindata(:,2),layers,"crossentropy",options);
Classify Testing Signals
Use the trained network to classify the motions for the testing data set. Display the results with a confusion chart.
testdata = readall(testDS);
scores = minibatchpredict(rawNet,testdata(:,1),MiniBatchSize=128);
classNames = categories(traindata{1,2});
predTest = scores2label(scores,classNames);
confusionchart(vertcat(testdata{:,2}),[predTest(:)],Normalization="column-normalized")
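An overall per-sample accuracy can be computed by comparing predicted and true labels element by element. The labels below are hypothetical toy values used only to illustrate the calculation; they are not results from this example.

```matlab
classes = ["HandOpen" "HandClose" "WristFlexion" "WristExtension"];

% Hypothetical true and predicted labels, one per signal sample
trueLabels = categorical(["HandOpen";"HandOpen";"WristFlexion";"HandClose";"WristExtension"],classes);
predLabels = categorical(["HandOpen";"WristExtension";"WristFlexion";"HandClose";"WristExtension"],classes);

% Fraction of samples whose predicted label matches the true label
accuracy = mean(predLabels == trueLabels)
```

With the real data, the same comparison could be applied to vertcat(testdata{:,2}) and predTest(:), assuming both are categorical arrays of equal length with matching categories.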
Conclusion
This example showed how to perform sequence-to-sequence classification to detect different arm motions based on EMG signals. The convolutional network achieved an overall accuracy of about 84%. Some misclassification occurred between hand open and wrist extension, and between hand close and wrist flexion. The hand open motion involves one of the muscles that is also used in wrist extension. Similarly, the hand close and wrist flexion motions can activate the same muscles. Further, the placement of EMG electrodes on the arm targeted mostly muscles used in wrist flexion, which had the highest classification accuracy.
The data for this example was collected by Professor Chan of Carleton University and can be found here: https://www.sce.carleton.ca/faculty/chan/index.php?page=matlab [1].
References
[1] Chan, Adrian D.C., and Geoffrey C. Green. 2007. "Myoelectric Control Development Toolbox." Paper presented at 30th Conference of the Canadian Medical & Biological Engineering Society, Toronto, Canada, 2007.
preprocess Function
function Tsds = preprocess(inputDS)

sig = inputDS{1};
roiTable = inputDS{2};

% Remove first and last rest periods from signal
sig(roiTable.ROI(end,2):end,:) = [];
sig(1:roiTable.ROI(1,1),:) = [];

% Shift ROI limits to account for deleting start and end of signal
roiTable.ROI = roiTable.ROI-(roiTable.ROI(1,1)-1);

% Create signal mask
m = signalMask(roiTable);
L = length(sig);

% Obtain sequence of category labels and remove pronation, supination, and rest motions
mask = catmask(m,L);
idx = ~ismember(mask,{'Pronation','Supination','Rest'});
mask = mask(idx);
sig = sig(idx,:);

% Create new signal mask without pronation and supination categories
m2 = signalMask(mask);
m2.SpecifySelectedCategories = true;
% m2.SelectedCategories = [1 2 3 4 7];
m2.SelectedCategories = [1 2 3 4];
mask = catmask(m2);

% Filter and downsample signal data
sigfilt = bandpass(sig,[10 400],3000);
downsig = downsample(sigfilt,3);

% Downsample label data
downmask = downsample(mask,3);

targetLength = 12000;
% Get number of chunks
numChunks = floor(size(downsig,1)/targetLength);

% Truncate signal and mask to integer number of chunks
sig = downsig(1:numChunks*targetLength,:);
mask = downmask(1:numChunks*targetLength);

% Create a cell array containing signal chunks
sigOut = {};
step = 0;
for i = 1:numChunks
    sigOut{i,1} = sig(1+step:i*targetLength,:);
    step = step+targetLength;
end

% Create a cell array containing mask chunks
lblOut = reshape(mask,targetLength,numChunks);
lblOut = num2cell(lblOut,1)';

% Output a two-column cell array with all chunks
Tsds = [sigOut,lblOut];
end