
Denoise EEG Signals Using Deep Learning Regression

This example shows how to remove electro-oculogram (EOG) noise from electroencephalogram (EEG) signals using the EEGdenoiseNet benchmark dataset [1] and deep learning regression. The EEGdenoiseNet dataset contains 4514 clean EEG segments and 3400 ocular artifact segments that can be used to synthesize noisy EEG segments with the ground-truth clean EEG (the dataset also contains muscular artifact segments, but these will not be used in this example).

This example uses clean and EOG-contaminated EEG signals to train a long short-term memory (LSTM) model to remove the EOG artifacts. The regression model is trained first with raw input signals and then with signals transformed by the short-time Fourier transform (STFT). The STFT model improves performance, especially at low SNR values.

Create the Dataset

The EEGdenoiseNet dataset contains 4514 clean EEG segments and 3400 EOG segments that can be used to generate three datasets for training, validating, and testing a deep learning model. The sample rate of all the signal segments is 256 Hz. For convenience, the dataset has been uploaded to this location: https://ssd.mathworks.com/supportfiles/SPT/data/EEGEOGDenoisingData.zip

Download the dataset using the downloadSupportFile function.

% Download the data
datasetZipFile = matlab.internal.examples.downloadSupportFile('SPT','data/EEGEOGDenoisingData.zip');
datasetFolder = fullfile(fileparts(datasetZipFile),'EEG_EOG_Denoising_Dataset');
if ~exist(datasetFolder,'dir')     
    unzip(datasetZipFile,fileparts(datasetZipFile));
end

After downloading the data, the location in datasetFolder contains two MAT files:

  • EEG_all_epochs.mat contains a matrix with 4514 clean EEG segments of length 512 samples

  • EOG_all_epochs.mat contains a matrix with 3400 EOG segments of length 512 samples

Use the createDataset helper function to generate training, validation, and testing datasets. The function combines clean EEG and EOG signals to generate pairs of clean and noisy EEG segments with different signal-to-noise ratios (SNR). For any EEG and EOG pair you can use the following combination equation to obtain a noisy segment with a given SNR:

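$$\text{noisyEEG} = \text{EEG} + \lambda \cdot \text{EOG}$$

$$\text{SNR} = 10 \log_{10} \frac{\operatorname{rms}(\text{EEG})}{\operatorname{rms}(\lambda \cdot \text{EOG})}$$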

You vary parameter λ to control the artifact power and achieve a particular SNR value.
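As a minimal sketch of how λ can be chosen, the SNR equation can be solved for λ given a target SNR. The signals below are random placeholders rather than dataset segments, and the names targetSNR, lambda, and rmsValue are illustrative assumptions. The createNoisySegment helper called by createDataset may be implemented differently.

```matlab
% Sketch: solve SNR = 10*log10(rms(EEG)/rms(lambda*EOG)) for lambda.
% EEG and EOG are random placeholders, not dataset segments.
rng(0)                                  % reproducible placeholders
EEG = randn(1,512);
EOG = 5*randn(1,512);
rmsValue = @(x) sqrt(mean(x.^2));       % root-mean-square value

targetSNR = -3;                         % desired SNR in dB (assumed value)
lambda = rmsValue(EEG)/(rmsValue(EOG)*10^(targetSNR/10));

% Combine the segments and check the SNR that was actually achieved
noisyEEG = EEG + lambda*EOG;
achievedSNR = 10*log10(rmsValue(EEG)/rmsValue(lambda*EOG));
```

With this definition, a lower (more negative) target SNR corresponds to a larger λ and therefore to a stronger artifact.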

To create the training dataset, createDataset combines the first 2720 pairs of EEG and EOG segments ten times each with random SNRs in the [-7, 2] dB interval, for a total of 27200 training pairs. Each training pair is stored in a MAT file inside a folder named train. Each MAT file includes:

  • A clean EEG segment (stored under a variable named EEG)

  • An EOG segment (stored under a variable named EOG)

  • A noisy EEG segment (stored under a variable named noisyEEG)

  • The SNR of the noisy segment (stored under a variable named SNR)

  • The sample rate value of the signal segments (stored under a variable named Fs)

To create the validation dataset, createDataset combines the next 340 pairs of EEG and EOG segments ten times each with random SNRs in the [-7, 2] dB interval, for a total of 3400 validation pairs. Validation data is stored in MAT files inside a folder named validate. Each MAT file contains the same variables as the ones described for the training set.

Finally, to create the test dataset, createDataset combines the next 340 pairs of EEG and EOG segments ten times each with deterministic SNR values of -7, -6, -5, -4, -3, -2, -1, 0, 1, and 2 dB. The test data is stored in MAT files inside a folder named test. Test MAT files with the same SNR value are grouped under a common subfolder to make it easier to analyze the denoising performance of the trained model for a given SNR. For example, files containing test signals with an SNR of -3 dB are stored in a folder named data_SNR_-3.
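The counts described above can be checked with a short calculation. This is only a sketch: the 80% and 10% figures match the trainingPercentage and validationPercentage values set in the createDataset helper, while treating the remaining pairs as the test set is an assumption inferred from the description, and all variable names are illustrative.

```matlab
% Sketch: file counts implied by the dataset description.
numPairs = 3400;                    % EEG/EOG pairs used by createDataset
numTrainPairs = numPairs*80/100;    % first 2720 pairs
numValPairs = numPairs*10/100;      % next 340 pairs
numTestPairs = numPairs - numTrainPairs - numValPairs; % remaining pairs (assumed)
repeatsPerPair = 10;                % each pair is combined ten times

numTrainFiles = numTrainPairs*repeatsPerPair;
numValFiles = numValPairs*repeatsPerPair;
numTestFiles = numTestPairs*repeatsPerPair;

% Test subfolder names, one per deterministic SNR value
testSNRs = -7:2;
testFolders = "data_SNR_" + testSNRs;
```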

Call the createDataset function to create the dataset (this may take a few seconds). Set the createDatasetFlag to false if you already have the dataset in the datasetFolder and want to skip this step.

createDatasetFlag = true;
if createDatasetFlag
    createDataset(datasetFolder);
end

Prepare Datastores to Consume the Data

The generated dataset is quite large (~430 MB), so it is convenient to use datastores to access the data without having to read it all at once into memory. Create signal datastores to access the training and validation data. Use the SignalVariableNames parameter to specify the variables you want to read from the MAT files (in the order you want them read). Also specify the ReadOutputOrientation as "row" to ensure the data is compatible with the LSTM network.

ds_Train = signalDatastore(fullfile(datasetFolder,"train"),SignalVariableNames=["noisyEEG","EEG"],ReadOutputOrientation="row")
ds_Train = 
  signalDatastore with properties:

                       Files:{
                             ' .../supportfiles/SPT/data/EEG_EOG_Denoising_Dataset/train/data_1.mat';
                             ' .../supportfiles/SPT/data/EEG_EOG_Denoising_Dataset/train/data_10.mat';
                             ' .../supportfiles/SPT/data/EEG_EOG_Denoising_Dataset/train/data_100.mat'
                              ... and 27197 more
                             }
                     Folders: {'/home/fboucher/Documents/MATLAB/Examples/R2021b/supportfiles/SPT/data/EEG_EOG_Denoising_Dataset/train'}
    AlternateFileSystemRoots: [0×0 string]
                    ReadSize: 1
         SignalVariableNames: ["noisyEEG"    "EEG"]
       ReadOutputOrientation: "row"

ds_Validate = signalDatastore(fullfile(datasetFolder,"validate"),SignalVariableNames=["noisyEEG","EEG"],ReadOutputOrientation="row");
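To illustrate how a datastore hands out one file at a time, here is a minimal sketch that builds a tiny temporary dataset and reads it incrementally. The folder, file contents, and the variable names tmpFolder, sds, and numRead are hypothetical placeholders and are not part of the example dataset.

```matlab
% Sketch: read signal pairs one file at a time from a small temporary dataset.
tmpFolder = tempname;               % hypothetical scratch folder
mkdir(tmpFolder)
for k = 1:3
    EEG = randn(1,512);             % placeholder clean segment
    noisyEEG = EEG + randn(1,512);  % placeholder noisy segment
    save(fullfile(tmpFolder,"data_" + k + ".mat"),"EEG","noisyEEG")
end

sds = signalDatastore(tmpFolder, ...
    SignalVariableNames=["noisyEEG","EEG"],ReadOutputOrientation="row");

numRead = 0;
while hasdata(sds)
    pair = read(sds);               % pair{1} is noisy, pair{2} is clean
    numRead = numRead + 1;
end
reset(sds)                          % rewind so the data can be read again
```

Only one file is held in memory per call to read, which is why the same pattern scales to the full training set.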

Read the data from the first training file and plot the clean and noisy EEG signals. A call to preview or read methods of the datastore yields a 1x2 cell array with the first element containing a noisy EEG segment, and the second element containing a clean EEG segment.

data = preview(ds_Train)
data=1×2 cell array
    {[293.5459 312.8255 158.2003 54.3755 -122.9328 … 42.2914 -65.1522]}    {[184.5071 182.3164 41.0644 -55.5457 -155.6309 … -29.6090 -84.8615]}

plot([data{2}.' data{1}.'],LineWidth=2)
legend('Clean EEG','EEG with EOG artifact')
axis tight

The performance of a regression network is usually improved if the input and output signals are normalized. You can transform the signal datastores to apply normalization to each signal as it is read from disk. The normalizeData helper function is listed at the end of this example. It simply subtracts the signal mean and divides the result by the signal's standard deviation.

ds_Train_T = transform(ds_Train,@normalizeData);
ds_Validate_T = transform(ds_Validate,@normalizeData);
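As a small sketch of what this normalization does to a single signal, the following applies the same formula to a short vector of made-up values and checks the result. The values and variable names are illustrative only.

```matlab
% Sketch: zero-mean, unit-standard-deviation normalization of one signal.
x = [2 4 4 4 5 5 7 9];              % made-up sample values
y = (x - mean(x))/std(x);           % same formula as normalizeData

yMean = mean(y);                    % numerically zero
yStd = std(y);                      % numerically one
```

Because the noisy input and the clean target are normalized independently, the network works in normalized units, which is why the example also normalizes the predictions before computing the MSE.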

Train a Regression Model to Denoise EEG Signals

Train a network to denoise signals by passing noisy EEG signals to the network input and requesting the clean ground-truth EEG signals at the network output. A long short-term memory (LSTM) architecture is chosen because it can learn features from time sequences.

Define the network architecture: the number of features is set to one as a single sequence is input to the network and a single sequence is output from the network. Use a dropout layer to reduce overfitting of the model on the training data. Use a regression layer as the output layer since the model is being trained to perform regression. Note that normalization must be applied to input and output signals so it is more convenient to use transformed datastores than to use the Normalization option of the sequenceInputLayer that only normalizes the inputs.

numFeatures = 1;
numHiddenUnits = 100;

layers = [
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits)
    dropoutLayer(0.2)
    fullyConnectedLayer(numFeatures)
    regressionLayer];

Define the training option parameters: use an Adam optimizer and choose to shuffle the data at every epoch. Also, specify the validation datastore ds_Validate_T as the source for the validation data.

maxEpochs = 5;
miniBatchSize = 150;

options = trainingOptions('adam', ...
    MaxEpochs=maxEpochs, ...
    MiniBatchSize=miniBatchSize, ...
    InitialLearnRate=0.005, ...
    GradientThreshold=1, ...
    Plots="training-progress", ...
    Shuffle="every-epoch", ...
    Verbose=false,...
    ValidationData=ds_Validate_T ,...
    ValidationFrequency=100, ...
    OutputNetwork="best-validation-loss");
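To get a feel for these settings, the following sketch estimates how many iterations the training runs and how often validation is triggered. It assumes that the final partial mini-batch of each epoch is discarded and counts only the periodic validation passes, so treat the numbers as approximate. The variable names are illustrative.

```matlab
% Sketch: approximate iteration counts implied by the training options.
numTrainFiles = 27200;              % training pairs created earlier
miniBatchSize = 150;
maxEpochs = 5;
validationFrequency = 100;          % iterations between validation passes

% Assumes incomplete mini-batches at the end of an epoch are dropped
iterationsPerEpoch = floor(numTrainFiles/miniBatchSize);
totalIterations = iterationsPerEpoch*maxEpochs;
numPeriodicValidations = floor(totalIterations/validationFrequency);
```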

Use the trainNetwork function to train the model. You can directly pass the transformed train datastore into the function because the datastore outputs a 1x2 cell array, with input and output signals, at each call to the read method.

The training steps take several minutes. You can skip them by downloading the pre-trained networks. To train the networks as the example runs, set trainingFlag to "Train networks". To skip the training steps, set trainingFlag to "Download networks". In that case, a MAT file containing two pre-trained networks, rawNet and stftNet, is downloaded to your machine.

trainingFlag = "Train networks";
if trainingFlag == "Train networks"
    rawNet = trainNetwork(ds_Train_T,layers,options);
else
    % Download the pre-trained networks
    modelsZipFile = matlab.internal.examples.downloadSupportFile('SPT','data/EEGEOGDenoisingNetworks.zip');
    modelsFolder = fullfile(fileparts(modelsZipFile),'EEG_EOG_Denoising_Networks');
    if ~exist(modelsFolder,'dir')
        unzip(modelsZipFile,fileparts(modelsZipFile));
    end
    modelsFile = fullfile(modelsFolder,'trainedNetworks.mat');    
    load(modelsFile)
end

Analyze the Denoising Performance of the Trained Model

Use the test dataset to analyze the denoising performance of the rawNet network. Recall that the test dataset contains multiple test files for each SNR value in [-7, -6, -5, -4, -3, -2, -1, 0, 1, 2] dB. The performance metric is the mean-squared error (MSE) between the clean baseline EEG signal and the denoised EEG signal. The MSE between the clean EEG signal and the noisy EEG signal is also computed to show the worst-case MSE when no denoising is applied. At each SNR, compute one MSE value for each of the 340 available test EEG segments and average the 340 values.
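The metric itself is simple to state in code. The sketch below computes the MSE for one placeholder signal with and without a stand-in "denoised" version. The signals are random, and the names mseFcn, mseDenoised, and mseNoDenoise are illustrative assumptions.

```matlab
% Sketch: per-segment MSE with and without denoising, on placeholder data.
rng(0)
clean = randn(1,512);                   % placeholder clean segment
noisy = clean + randn(1,512);           % placeholder noisy segment
denoised = clean + 0.1*randn(1,512);    % stand-in for a network output

mseFcn = @(a,b) sum((a - b).^2)/numel(a);

mseDenoised = mseFcn(clean,denoised);   % error remaining after denoising
mseNoDenoise = mseFcn(clean,noisy);     % worst case: no denoising applied
```

Averaging such per-segment values over all test segments at one SNR gives a single point on the performance curve.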

Create a signalDatastore to consume the test data and use a transformed datastore to set up data normalization. Since the data is now inside subfolders of the test folder, specify IncludeSubfolders as true. Further, use the folders2labels function to get the list of folder names for each file in the test dataset so that you can get data for each SNR.

ds_Test = signalDatastore(fullfile(datasetFolder,"test"),SignalVariableNames=["noisyEEG","EEG"],IncludeSubfolders=true,ReadOutputOrientation="row");
ds_Test_T = transform(ds_Test,@normalizeData);

% Get labels that contain the SNR value for each file in the datastore
labels = folders2labels(ds_Test)
labels = 3400×1 categorical
     data_SNR_-1 
     data_SNR_-1 
     data_SNR_-1 
     data_SNR_-1 
     data_SNR_-1 
     data_SNR_-1 
     data_SNR_-1 
     data_SNR_-1 
     data_SNR_-1 
     data_SNR_-1 
     data_SNR_-1 
     data_SNR_-1 
     data_SNR_-1 
     data_SNR_-1 
     data_SNR_-1 
     data_SNR_-1 
     data_SNR_-1 
     data_SNR_-1 
     data_SNR_-1 
     data_SNR_-1 
     data_SNR_-1 
     data_SNR_-1 
     data_SNR_-1 
     data_SNR_-1 
     data_SNR_-1 
     data_SNR_-1 
     data_SNR_-1 
     data_SNR_-1 
     data_SNR_-1 
     data_SNR_-1 
      ⋮

For each SNR value, denoise the test signals and compute the average MSE value. Use the subset function of the datastore to get a datastore pointing to the data for each SNR. To denoise a signal, call the predict function, passing the trained network and the noisy data as inputs.

SNRs = (-7:2);
MSE_Denoised_rawNet = zeros(numel(SNRs),1); % Measure denoising performance
MSE_No_Denoise = zeros(numel(SNRs),1); % Measure worst-case MSE when no denoising is applied

for idx = 1:numel(SNRs)
    lblIdx = find(labels == "data_SNR_"+num2str(SNRs(idx)));
    ds_Test_SNR = subset(ds_Test_T,lblIdx); % New datastore pointing to files with current SNR value

    % Denoise the data using the predict function of the trained model
    pred = predict(rawNet,ds_Test_SNR);

    % Use an array datastore to loop over the 340 denoised signals for the
    % current SNR value. Transform the datastore to add the normalization
    % step. 
    ds_Pred = transform(arrayDatastore(pred,OutputType="same"),@normalizeData);   

    mse = 0;
    mseWorstCase = 0;
    cnt = 0;
    while hasdata(ds_Pred)

        testData = read(ds_Test_SNR);
        denoisedData = read(ds_Pred);

        % MSE performance of denoiser - testData{2} contains clean EEG,
        % testData{1} contains noisy EEG.
        mse = mse + sum((testData{2} - denoisedData{1}).^2)/numel(denoisedData{1});

        % Worst-case MSE performance when no denoising is applied.
        % Convert data to single precision as denoisedData is single
        % precision.
        mseWorstCase = mseWorstCase + sum((single(testData{2}) - single(testData{1})).^2)/numel(testData{1});
        cnt = cnt+1;
    end

    % Average MSE of denoised signals
    MSE_Denoised_rawNet(idx) = mse/cnt;

    % Worst-case average MSE
    MSE_No_Denoise(idx) = mseWorstCase/cnt;
end

Plot the average MSE results.

plot(SNRs,[MSE_No_Denoise,MSE_Denoised_rawNet],LineWidth=2)
xlabel("SNR")
ylabel("Average MSE")
title("Denoising Performance")
legend("Worst-case scenario (no denoising)","Denoising with rawNet model")

Improve Performance Using Short-Time Fourier Transform Feature Extraction

A common approach to improve performance of a deep learning model is to use extracted features in place of the original raw signal data. The features provide a representation of the input data that makes it easier for the network to learn the most important aspects of the signals.

Choose a short-time Fourier transform (STFT) with a window length of 64 samples and an overlap length of 63 samples. This transformation effectively creates 33 complex features with a length of 449 samples each. LSTM networks do not support complex inputs, so separate the complex features into real and imaginary components by stacking the real part of the features on top of the imaginary part. This yields 66 real features, each of length 449 samples.

winLength = 64;
overlapLength = 63;
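The feature dimensions quoted above follow from the window settings. The sketch below works through the arithmetic, assuming that the number of DFT points equals the window length, which is consistent with the 33 one-sided features stated in the text. The variable names are illustrative.

```matlab
% Sketch: STFT output dimensions for a 512-sample segment.
signalLength = 512;
winLength = 64;
overlapLength = 63;

hopLength = winLength - overlapLength;      % window advances by 1 sample
numFreqBins = winLength/2 + 1;              % one-sided bins (assumes DFT length = window length)
numFrames = floor((signalLength - overlapLength)/hopLength);
numRealFeatures = 2*numFreqBins;            % real part stacked on imaginary part
```

The high overlap means nearly one STFT frame per input sample, so the time resolution of the feature sequence stays close to that of the raw signal.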

The transformSTFT helper function listed at the end of this example normalizes the input signal and then computes its STFT. The function stacks the real and imaginary components to create a real output matrix. Further, if a GPU is available, the function moves the data to the GPU to accelerate the STFT computations and mitigate the increased complexity of computing the transforms. If you do not wish to use the GPU, set useGPUFlag to false.

useGPUFlag = true;

Compute and plot the STFT of a pair of clean and noisy EEG signals using the transformSTFT helper function.

data = preview(ds_Train);
P = transformSTFT(data,winLength,overlapLength,useGPUFlag);
figure
subplot(1,2,1)
h = imagesc(P{2});
h.Parent.CLim = [-40 57];
title('STFT of clean EEG signal')
ylabel("Stacked real and imaginary features")
subplot(1,2,2)
h = imagesc(P{1});
h.Parent.CLim = [-40 57];
ylabel("Stacked real and imaginary features")
title('STFT of noisy EEG signal')

The idea is to train a network so that it can produce denoised STFT signal representations based on STFT inputs corresponding to noisy signals. Note that the target outcome is a denoised signal, not its denoised STFT representation, so a final step must be added to compute the inverse STFT (ISTFT) to recover the denoised signal as depicted on the block diagram below.

The helper function transformISTFT, listed at the end of this example, takes the denoised STFT network output, converts the stacked real and imaginary features back to complex features, and computes the inverse STFT. As a final step, the function normalizes the resulting signal. If a GPU is available and useGPUFlag is true, the function performs all the computations on the GPU to reduce the processing time.
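The stacking and unstacking of real and imaginary parts is a lossless round trip, as the following sketch shows on a random complex matrix with the same size as one STFT. The matrix is a placeholder and the variable names are illustrative. No actual STFT is computed here.

```matlab
% Sketch: stack real and imaginary parts, then rebuild the complex matrix.
rng(1)
Y = complex(randn(33,449),randn(33,449));   % placeholder for one STFT

% Stack: 33 complex rows become 66 real rows
P = [real(Y); imag(Y)];

% Unstack: top half is the real part, bottom half is the imaginary part
numRows = size(P,1);
Yrec = P(1:numRows/2,:) + 1i*P(1+numRows/2:end,:);

maxError = max(abs(Y(:) - Yrec(:)));
```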

Create train, validation, and test datastores to apply STFT using the transformSTFT function.

ds_Train_STFT = transform(ds_Train,@(d)transformSTFT(d,winLength,overlapLength,useGPUFlag));
ds_Validate_STFT = transform(ds_Validate,@(d)transformSTFT(d,winLength,overlapLength,useGPUFlag));
ds_Test_STFT = transform(ds_Test,@(d)transformSTFT(d,winLength,overlapLength,useGPUFlag));

Update the network architecture to account for 66 input and output features and specify the new validation data in the training options. Every other network parameter or option is unchanged.

numFeatures = winLength + 2;

layers = [
    sequenceInputLayer(numFeatures)
    lstmLayer(numHiddenUnits)
    dropoutLayer(0.2)
    fullyConnectedLayer(numFeatures)
    regressionLayer];

options.ValidationData = ds_Validate_STFT;

Train the network if trainingFlag is "Train networks".

if trainingFlag == "Train networks"
    stftNet = trainNetwork(ds_Train_STFT,layers,options);
end

Use the trained network to denoise EEG signals using the test data. Compute average MSE values by comparing denoised and clean baseline EEG signals.

MSE_Denoised_stftNet = zeros(numel(SNRs),1); % Measure denoising performance
for idx = 1:numel(SNRs)
    lblIdx = find(labels == "data_SNR_"+num2str(SNRs(idx)));
    % New datastores pointing to files with current SNR value
    ds_Test_SNR = subset(ds_Test_T,lblIdx); % Raw EEG signals to compute MSE
    ds_Test_STFT_SNR = subset(ds_Test_STFT,lblIdx); % STFT transforms
   
    % Denoise the data using the predict function of the trained model. 
    pred = predict(stftNet,ds_Test_STFT_SNR);

    % Use an array datastore to loop over the 340 denoised signals for the
    % current SNR value. Transform the datastore to compute the inverse
    % STFT and recover the actual denoised signal.
    ds_Pred = transform(arrayDatastore(pred,OutputType="same"),@(P)transformISTFT(P,winLength,overlapLength));

    mse = 0;   
    cnt = 0;
    while hasdata(ds_Pred)

        testData = read(ds_Test_SNR);
        denoisedData = read(ds_Pred);

        % MSE performance of denoiser - testData{2} contains clean EEG
        mse = mse + sum((testData{2} - denoisedData).^2)/numel(denoisedData);
        cnt = cnt+1;
    end

    % Average MSE of denoised signals
    MSE_Denoised_stftNet(idx) = mse/cnt;
end

Plot the average MSE obtained with no denoising, denoising with a network trained with raw input signals, and denoising with a network trained with STFT-transformed signals. You can see that adding the STFT step improves the performance, especially at the lower SNR values.

figure
plot(SNRs,[MSE_No_Denoise,MSE_Denoised_rawNet,MSE_Denoised_stftNet],LineWidth=2)
xlabel("SNR")
ylabel("Average MSE")
title("Denoising Performance")
legend("Worst-case scenario (no denoising)","Denoising with rawNet model","Denoising with stftNet model")

Plot noisy and denoised signals for different SNRs. The getRandomEEG helper function listed at the end of this example gets a random EEG signal with a specified SNR from the test dataset.

SNR = -7; % dB
data = getRandomEEG(datasetFolder,SNR);
noisyEEG = normalizeData(data{1});
cleanEEG = normalizeData(data{2});
stftInput = transformSTFT(noisyEEG,winLength,overlapLength,useGPUFlag);
denoisedEEG = transformISTFT(predict(stftNet,stftInput),winLength,overlapLength);

plot([cleanEEG.' denoisedEEG.' noisyEEG.'],LineWidth=2)
title("EEG denoising (SNR = " + SNR + " dB)")
legend("Clean EEG", "Denoised EEG","Noisy EEG")
axis tight

Conclusion

In this example you learned how to train a deep network to perform regression for signal denoising. You compared two models, one trained with raw clean and noisy EEG signals, the other trained with features extracted using a short-time Fourier transform. You learned that you can use complex features by stacking their real and imaginary components and treating them as independent real features. The use of STFT sequences provides greater performance improvement at worse SNRs and both approaches converge in performance as the SNR improves.

References

[1] Haoming Zhang, Mingqi Zhao, Chen Wei, Dante Mantini, Zherui Li, Quanying Liu. "A benchmark dataset for deep learning solutions of EEG denoising." https://arxiv.org/abs/2009.11662

Helper Functions

normalizeData - this function normalizes input signals by subtracting the mean and dividing by the standard deviation.

function y = normalizeData(x)
% This function is only intended to support examples in the Signal
% Processing Toolbox. It may be changed or removed in a future release.

if iscell(x)
    y = cell(1,numel(x));
    y{1} = (x{1}-mean(x{1}))/std(x{1});

    if numel(x) == 2
        y{2} = (x{2}-mean(x{2}))/std(x{2});
    end
else
    y = (x - mean(x))/std(x);
end
end

transformSTFT - this function normalizes the signals in input data and computes their short-time Fourier transform. It converts the complex STFT results into a real matrix by stacking the real and imaginary components one on top of the other.

function P = transformSTFT(data,winLength,overlapLength,useGPUFlag)
% This function is only intended to support examples in the Signal
% Processing Toolbox. It may be changed or removed in a future release.

if ~iscell(data)
    data = {data};
end

P = cell(1,numel(data));

x = data{1};
if useGPUFlag
    x = gpuArray(x);
end
x = normalizeData(x);
y = stft(x,Window=rectwin(winLength),OverlapLength=overlapLength,FrequencyRange="onesided");
P{1} = [real(y);imag(y)];

if numel(data) == 2
    x = data{2};
    if useGPUFlag
        x = gpuArray(x);
    end
    x = normalizeData(x);
    y = stft(x,Window=rectwin(winLength),OverlapLength=overlapLength,FrequencyRange="onesided");
    P{2} = [real(y);imag(y)];
end
end

transformISTFT - this function takes a matrix with stacked real and imaginary STFT elements and combines them back to a complex STFT matrix. The function then computes the inverse STFT transform and normalizes the resulting reconstructed signals.

function data = transformISTFT(P,winLength,overlapLength)
% This function is only intended to support examples in the Signal
% Processing Toolbox. It may be changed or removed in a future release.
PP = P{1};
NumRows = size(PP,1);
X = PP(1:NumRows/2,:)+1i*PP(1+NumRows/2:end,:);
data = istft(X,Window=rectwin(winLength),OverlapLength=overlapLength,ConjugateSymmetric=true,FrequencyRange="onesided").';
data = normalizeData(data);
end

createDataset - this function combines clean EEG signal segments with EOG segments to create training, validation and testing datasets to train an EEG denoiser neural network.

function createDataset(dataDir)
% This function is only intended to support examples in the Signal
% Processing Toolbox. It may be changed or removed in a future release.

% Create training, validation, and testing datasets consisting of clean EEG
% signals and noisy EEG signals contaminated by EOG segments. 

load(fullfile(dataDir,"EEG_all_epochs.mat"),"EEG_all_epochs");
load(fullfile(dataDir,"EOG_all_epochs.mat"),"EOG_all_epochs");

EEG_all_epochs = EEG_all_epochs(1:3400,:).';
EOG_all_epochs = EOG_all_epochs.';
Fs = 256;
trainingPercentage = 80;
validationPercentage = 10;
N = size(EEG_all_epochs,2);

% Create a training dataset consisting of mat files containing two signals
% - a clean EEG signal, and an EEG signal contaminated by EOG artifacts.
% Combine each of 2720 pairs of EEG and EOG segments ten times with random
% SNRs in the range -7dB to 2dB to obtain 27200 training segments.

EEG_training = EEG_all_epochs(:,1:N*trainingPercentage/100);
EOG_training = EOG_all_epochs(:,1:N*trainingPercentage/100);

M = size(EEG_training,2);
cnt = 0;
if ~exist(fullfile(dataDir,"train"),'dir')
    mkdir(fullfile(dataDir,"train"))
end
for idx = 1:M
    for kk = 1:10
        cnt = cnt + 1;
        EEG = EEG_training(:,idx).';
        EOG = EOG_training(:,idx).';
        [noisyEEG,SNR] = createNoisySegment(EEG,EOG,[-7,2]);
        save(fullfile(dataDir,"train","data_" + num2str(cnt) + ".mat"),"EEG","EOG","noisyEEG","Fs","SNR");
    end
end

% Create a validation dataset by combining each of 340 pairs of EEG and EOG
% segments ten times with random SNRs in the range -7 dB to 2 dB

EEG_validation = EEG_all_epochs(:,1+N*trainingPercentage/100:N*trainingPercentage/100+N*validationPercentage/100);
EOG_validation = EOG_all_epochs(:,1+N*trainingPercentage/100:N*trainingPercentage/100+N*validationPercentage/100);

M = size(EEG_validation,2);
cnt = 0;
if ~exist(fullfile(dataDir,"validate"),'dir')
    mkdir(fullfile(dataDir,"validate"))
end
for idx = 1:M
    for kk = 1:10
        cnt = cnt + 1;
        EEG = EEG_validation(:,idx).';
        EOG = EOG_validation(:,idx).';
        [noisyEEG,SNR] = createNoisySegment(EEG,EOG,[-7,2]);
        save(fullfile(dataDir,"validate","data_" + num2str(cnt) + ".mat"),"EEG","EOG","noisyEEG","Fs","SNR");
    end
end

% Create a test dataset by combining each of 340 pairs of EEG and EOG
% segments at 10 SNR values [-7 -6 -5 -4 -3 -2 -1 0 1 2] dB. Store the
% test sets in folders with names that identify the SNR value so that
% it is easy to analyze performance by accessing files with a specific SNR.

EEG_test = EEG_all_epochs(:,1+N*trainingPercentage/100+N*validationPercentage/100:end);
EOG_test = EOG_all_epochs(:,1+N*trainingPercentage/100+N*validationPercentage/100:end);

M = size(EEG_test,2);
SNRVect = (-7:2);
for kk = 1:numel(SNRVect)
    cnt = 0;
    if ~exist(fullfile(dataDir,"test","data_SNR_" + num2str(SNRVect(kk))),'dir')
        mkdir(fullfile(dataDir,"test","data_SNR_" + num2str(SNRVect(kk))));
    end
    for idx = 1:M
        cnt = cnt + 1;
        EEG = EEG_test(:,idx).';
        EOG = EOG_test(:,idx).';
        [noisyEEG,SNR] = createNoisySegment(EEG,EOG,SNRVect(kk));
        save(fullfile(dataDir,"test","data_SNR_" + num2str(SNR),"data_" + num2str(cnt) + ".mat"),"EEG","EOG","noisyEEG","Fs","SNR");
    end
end
end
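
The index arithmetic in createDataset determines how many files each split contains. The following sketch reproduces only that arithmetic, using the values set in the function. The names of the count variables are illustrative and do not appear in the helper.

```matlab
% Values taken from createDataset
N = 3400;                   % EEG/EOG pairs kept after truncation
trainingPercentage = 80;
validationPercentage = 10;
SNRVect = (-7:2);

% Column ranges used for each split
trainIdx = 1:N*trainingPercentage/100;
valIdx = 1+N*trainingPercentage/100:N*trainingPercentage/100+N*validationPercentage/100;
testIdx = 1+N*trainingPercentage/100+N*validationPercentage/100:N;

% Each training and validation pair is mixed ten times at random SNRs.
% Each test pair is mixed once per SNR value in SNRVect.
numTrainFiles = 10*numel(trainIdx);
numValFiles = 10*numel(valIdx);
numTestFilesPerSNR = numel(testIdx);
numTestFiles = numTestFilesPerSNR*numel(SNRVect);
```

With these values the splits hold 2720, 340, and 340 pairs, which gives 27200 training files, 3400 validation files, and 340 test files in each of the 10 SNR folders.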

function [y,SNROut] = createNoisySegment(eeg,artifact,SNR)
% Combine EEG and artifact signals with a specified SNR in dB. If SNR is a
% two-element vector, its value is chosen randomly from a uniform
% distribution over the interval [SNR(1) SNR(2)]

if numel(SNR) == 2
    SNR = SNR(1) + (SNR(2)-SNR(1)).*rand(1,1);
end

k = 10^(SNR/10);
lambda = (1/k)*rms(eeg)/rms(artifact);

y = eeg + lambda * artifact;

SNROut = SNR;
end
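
The scaling in createNoisySegment implies that the ratio rms(eeg)/rms(lambda*artifact) equals 10^(SNR/10), so the SNR in dB is 10*log10 of that RMS ratio. The following sketch checks this numerically. The random stand-in signals and the target SNR of -3 dB are assumptions for illustration and are not dataset segments.

```matlab
% Stand-in signals (hypothetical; real segments come from the MAT files)
rng(0)
eeg = randn(1,512);
artifact = 4*randn(1,512);
targetSNR = -3;             % dB, illustrative

% Same scaling as createNoisySegment
k = 10^(targetSNR/10);
lambda = (1/k)*rms(eeg)/rms(artifact);
noisyEEG = eeg + lambda*artifact;

% Recover the SNR from the clean signal and the scaled artifact
measuredSNR = 10*log10(rms(eeg)/rms(noisyEEG - eeg));
fprintf("target = %.2f dB, measured = %.2f dB\n",targetSNR,measuredSNR);
```

The measured value agrees with the target up to floating-point rounding, regardless of the original amplitude of the artifact, because lambda rescales the artifact relative to the RMS of the clean EEG.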

getRandomEEG - this function reads the data from a random EEG test file with a specified SNR.

function data = getRandomEEG(datasetFolder,SNR)
sds  = signalDatastore(fullfile(datasetFolder,"test","data_SNR_"+num2str(SNR)),SignalVariableNames=["noisyEEG","EEG"],IncludeSubfolders=true);
n = numel(sds.Files);
idx = randi(n,1);
data = read(subset(sds,idx));
end
