
scatteringTransform

Wavelet joint time-frequency scattering transform

Since R2024b

    Description

    [outCFS,outMETA] = scatteringTransform(jtfn,x) returns the joint time-frequency scattering (JTFS) transform outCFS and metadata outMETA of x for the JTFS network jtfn.


    [outCFS,outMETA] = scatteringTransform(___,Name=Value) specifies options using one or more name-value arguments in addition to the input arguments in the previous syntax. For example, to average along the time dimension for all JTFS coefficients, set TimeAverage to "global".


    Examples


    Create a single-precision random signal with three channels and 1024 samples representing a batch of 5. Save the signal as a dlarray in "CTB" format.

    nchan = 3;
    nsam = 1024;
    nbatch = 5;
    sig = single(randn([nchan nsam nbatch]));
    x = dlarray(sig,"CTB");

    Create a JTFS network appropriate for the signal. Set the filter data type of the network to "single".

    jtfn = timeFrequencyScattering(SignalLength=nsam, ...
        FilterDataType="single");

    Use the scatteringTransform function to obtain the JTFS transform of the signal. Also obtain the transform metadata.

    [outCFS,outMETA] = scatteringTransform(jtfn,x);

    Inspect the JTFS coefficient arrays. The format of each coefficient array is path-by-frequency-by-time-by-channel-by-batch.

    outCFS
    outCFS =
    
      dictionary (string ⟼ cell) with 5 entries:
    
        "S1FreqLowpass"       ⟼ {5-D dlarray}
        "S1SpinUpFreqLowpass" ⟼ {5-D dlarray}
        "SpinUp"              ⟼ {5-D dlarray}
        "SpinDown"            ⟼ {5-D dlarray}
        "U2JointLowpass"      ⟼ {5-D dlarray}
    

    If the input signal is a formatted or unformatted dlarray, every dictionary value is an unformatted dlarray. Choose any dictionary value. Confirm that it is an unformatted dlarray (dims returns an empty character array) and that its underlying data type is single precision.

    key = "S1SpinUpFreqLowpass";
    val = outCFS{key};
    dims(val)
    ans =
    
      0×0 empty char array
    
    underlyingType(val)
    ans = 
    'single'
    

    Inspect the SpinUp coefficients array and its metadata. The metadata in the ith table row describes the coefficients outCFS{"SpinUp"}(i,:,:,:,:).

    cfs = outCFS{"SpinUp"};
    [numPath,numFrequency,numTime,numChannel,numBatch] = size(cfs) %#ok<*ASGLU>
    numPath = 
    35
    
    numFrequency = 
    6
    
    numTime = 
    8
    
    numChannel = 
    3
    
    numBatch = 
    5
    
    outMETA{3}
    ans=35×5 table
          type      log2dsfactor     path     spin    log2stride
        ________    ____________    ______    ____    __________
    
        "SpinUp"       0    1       1    3     1        3    7  
        "SpinUp"       0    1       2    3     1        3    7  
        "SpinUp"       1    1       3    3     1        3    7  
        "SpinUp"       2    1       4    3     1        3    7  
        "SpinUp"       2    1       5    3     1        3    7  
        "SpinUp"       0    2       1    4     1        3    7  
        "SpinUp"       0    2       2    4     1        3    7  
        "SpinUp"       1    2       3    4     1        3    7  
        "SpinUp"       2    2       4    4     1        3    7  
        "SpinUp"       2    2       5    4     1        3    7  
        "SpinUp"       0    3       1    5     1        3    7  
        "SpinUp"       0    3       2    5     1        3    7  
        "SpinUp"       1    3       3    5     1        3    7  
        "SpinUp"       2    3       4    5     1        3    7  
        "SpinUp"       2    3       5    5     1        3    7  
        "SpinUp"       0    4       1    6     1        3    7  
          ⋮
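    Because row i of the metadata table describes path i of the coefficient array, you can index the table variables to select paths. The following sketch recreates this example's setup (dlarray requires Deep Learning Toolbox) and selects the "SpinUp" paths whose second-order time wavelet index, the second column of path, is 3. The value 3 is an arbitrary illustration, and the sketch locates the table by its type value rather than assuming its position in outMETA.

```matlab
% Recreate the setup from this example
nchan = 3; nsam = 1024; nbatch = 5;
x = dlarray(single(randn([nchan nsam nbatch])),"CTB");
jtfn = timeFrequencyScattering(SignalLength=nsam,FilterDataType="single");
[outCFS,outMETA] = scatteringTransform(jtfn,x);

% Find the metadata table whose type is "SpinUp"
isSpinUp = cellfun(@(t) t.type(1) == "SpinUp",outMETA);
meta = outMETA{isSpinUp};
cfs = outCFS{"SpinUp"};

% Select the paths that use second-order time wavelet 3 (illustrative choice)
rows = find(meta.path(:,2) == 3);
sel = cfs(rows,:,:,:,:);

% sel is numel(rows)-by-frequency-by-time-by-channel-by-batch
size(sel)
```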
    
    

    Inspect the U2JointLowpass coefficients array and its metadata. Because the scatteringTransform function did not use spin-up or spin-down wavelets to compute these coefficients, the spin value for all coefficient paths is 0.

    cfs = outCFS{"U2JointLowpass"};
    [numPath,numFrequency,numTime,numChannel,numBatch] = size(cfs)
    numPath = 
    7
    
    numFrequency = 
    6
    
    numTime = 
    8
    
    numChannel = 
    3
    
    numBatch = 
    5
    
    outMETA{5}
    ans=7×5 table
              type          log2dsfactor      path      spin    log2stride
        ________________    ____________    ________    ____    __________
    
        "U2JointLowpass"         1          -1     3     0        3    7  
        "U2JointLowpass"         2          -1     4     0        3    7  
        "U2JointLowpass"         3          -1     5     0        3    7  
        "U2JointLowpass"         4          -1     6     0        3    7  
        "U2JointLowpass"         5          -1     7     0        3    7  
        "U2JointLowpass"         6          -1     8     0        3    7  
        "U2JointLowpass"         6          -1     9     0        3    7  
    
    

    Load the ECG signal data. The data has 2048 samples. Create a JTFS network appropriate for the signal.

    load wecg
    len = length(wecg);
    jtfn = timeFrequencyScattering(SignalLength=len);

    Obtain the JTFS transform of the signal using default function parameters. Also obtain the transform metadata. By default, scatteringTransform critically downsamples values in time and frequency. Because the data contains one batch of a single-channel signal, the format of the coefficient dictionary values is path-by-frequency-by-time.

    [outCFS,outMETA] = scatteringTransform(jtfn,wecg);
    outCFS
    outCFS =
    
      dictionary (string ⟼ cell) with 5 entries:
    
        "S1FreqLowpass"       ⟼ {1×7×8 double}
        "S1SpinUpFreqLowpass" ⟼ {5×7×8 double}
        "SpinUp"              ⟼ {40×7×8 double}
        "SpinDown"            ⟼ {40×7×8 double}
        "U2JointLowpass"      ⟼ {8×7×8 double}
    

    Obtain the JTFS transform with TimeOversamplingFactor set to 1. Because the oversampling factor is on a base-2 logarithmic scale, a factor of 1 doubles the size of the time dimension in the coefficient arrays. The sizes of the path and frequency dimensions remain the same.

    [outCFS_T1,outMETA_T1] = scatteringTransform(jtfn,wecg, ...
        TimeOversamplingFactor=1);
    outCFS_T1
    outCFS_T1 =
    
      dictionary (string ⟼ cell) with 5 entries:
    
        "S1FreqLowpass"       ⟼ {1×7×16 double}
        "S1SpinUpFreqLowpass" ⟼ {5×7×16 double}
        "SpinUp"              ⟼ {40×7×16 double}
        "SpinDown"            ⟼ {40×7×16 double}
        "U2JointLowpass"      ⟼ {8×7×16 double}
    

    Compare the first five rows of the "SpinDown" metadata tables. The second column of the log2dsfactor and log2stride table variables describes downsampling in time. With a time oversampling factor of 1, each of those values in the second table is 1 less than the corresponding value in the first table.

    outMETA{4}(1:5,:)
    ans=5×5 table
           type       log2dsfactor      path      spin    log2stride
        __________    ____________    ________    ____    __________
    
        "SpinDown"       0    1        6     3     -1       3    8  
        "SpinDown"       0    1        7     3     -1       3    8  
        "SpinDown"       1    1        8     3     -1       3    8  
        "SpinDown"       2    1        9     3     -1       3    8  
        "SpinDown"       2    1       10     3     -1       3    8  
    
    
    outMETA_T1{4}(1:5,:)
    ans=5×5 table
           type       log2dsfactor      path      spin    log2stride
        __________    ____________    ________    ____    __________
    
        "SpinDown"       0    0        6     3     -1       3    7  
        "SpinDown"       0    0        7     3     -1       3    7  
        "SpinDown"       1    0        8     3     -1       3    7  
        "SpinDown"       2    0        9     3     -1       3    7  
        "SpinDown"       2    0       10     3     -1       3    7  
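    To make the comparison programmatic, subtract the time columns of the two metadata tables. This sketch repeats the ECG computations from this example. It computes the differences for every row but displays the distinct values instead of assuming they are all 1, because only the first five rows are shown above.

```matlab
load wecg
jtfn = timeFrequencyScattering(SignalLength=length(wecg));
[~,outMETA] = scatteringTransform(jtfn,wecg);
[~,outMETA_T1] = scatteringTransform(jtfn,wecg,TimeOversamplingFactor=1);

% "SpinDown" metadata for the default and time-oversampled transforms
m0 = outMETA{4};
m1 = outMETA_T1{4};

% Column 2 of log2stride and log2dsfactor corresponds to time
dStride = m0.log2stride(:,2) - m1.log2stride(:,2);
dDS = m0.log2dsfactor(:,2) - m1.log2dsfactor(:,2);

% Distinct differences across all rows
unique(dStride)
unique(dDS)
```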
    
    

    Now obtain the JTFS transform of the signal with FrequencyOversamplingFactor set to 1. Compared with the first transform, the size of the frequency dimension in the coefficient arrays is twice as large. The sizes of the path and time dimensions are the same.

    [outCFS_F1,outMETA_F1] = scatteringTransform(jtfn,wecg, ...
        FrequencyOversamplingFactor=1);
    outCFS_F1
    outCFS_F1 =
    
      dictionary (string ⟼ cell) with 5 entries:
    
        "S1FreqLowpass"       ⟼ {1×14×8 double}
        "S1SpinUpFreqLowpass" ⟼ {5×14×8 double}
        "SpinUp"              ⟼ {40×14×8 double}
        "SpinDown"            ⟼ {40×14×8 double}
        "U2JointLowpass"      ⟼ {8×14×8 double}
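    The following sketch checks the size change for every dictionary key at once by dividing the sizes of the oversampled arrays by the sizes of the default arrays. It repeats the ECG computations from this example; based on the outputs above, the expected ratio is 1 for the path and time dimensions and 2 for the frequency dimension.

```matlab
load wecg
jtfn = timeFrequencyScattering(SignalLength=length(wecg));
outCFS = scatteringTransform(jtfn,wecg);
outCFS_F1 = scatteringTransform(jtfn,wecg,FrequencyOversamplingFactor=1);

k = keys(outCFS);
ratio = zeros(numel(k),3);
for i = 1:numel(k)
    % Size ratio (oversampled / default) for path, frequency, and time
    s0 = size(outCFS{k(i)});
    s1 = size(outCFS_F1{k(i)});
    ratio(i,:) = s1(1:3) ./ s0(1:3);
end
table(k,ratio,VariableNames=["key" "sizeRatio"])
```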
    

    Create a single-precision random signal with three channels and 1000 samples representing a batch of 5. For 3-D numeric input, scatteringTransform assumes the dimensions are time-by-channel-by-batch. Save the signal as a gpuArray.

    nsam = 1000;
    nchan = 3;
    nbatch = 5;
    sig = single(randn([nsam nchan nbatch]));
    x = gpuArray(sig);

    Create a JTFS network appropriate for the signal.

    jtfn = timeFrequencyScattering(SignalLength=nsam, ...
        FilterDataType="single");

    Obtain the JTFS transform of the signal using default settings. By default, scatteringTransform averages the coefficients in time and frequency with lowpass filters.

    outCFS = scatteringTransform(jtfn,x)
    outCFS =
    
      dictionary (string ⟼ cell) with 5 entries:
    
        "S1FreqLowpass"       ⟼ {5-D gpuArray}
        "S1SpinUpFreqLowpass" ⟼ {5-D gpuArray}
        "SpinUp"              ⟼ {5-D gpuArray}
        "SpinDown"            ⟼ {5-D gpuArray}
        "U2JointLowpass"      ⟼ {5-D gpuArray}
    

    Obtain the dimensions of the coefficient arrays. The arrays are in path-by-frequency-by-time-by-channel-by-batch format.

    dictionaryValues = values(outCFS);
    cellfun(@size,dictionaryValues,UniformOutput=false)
    ans=5×1 cell array
        {[ 1 6 7 3 5]}
        {[ 5 6 7 3 5]}
        {[35 6 7 3 5]}
        {[35 6 7 3 5]}
        {[ 7 6 7 3 5]}
    
    

    Obtain the JTFS transform of the signal with TimeAverage set to "global". Instead of using lowpass filtering, the function takes the mean along the time dimension for all the coefficients. The size of the time dimension in the coefficient arrays is 1.

    outCFS_T = scatteringTransform(jtfn,x, ...
        TimeAverage="global");
    dictionaryValues_T = values(outCFS_T);
    cellfun(@size,dictionaryValues_T,UniformOutput=false)
    ans=5×1 cell array
        {[ 1 6 1 3 5]}
        {[ 5 6 1 3 5]}
        {[35 6 1 3 5]}
        {[35 6 1 3 5]}
        {[ 7 6 1 3 5]}
    
    

    Obtain the JTFS transform of the signal with FrequencyAverage set to "global". Instead of using lowpass filtering, the function takes the mean along the frequency dimension for all the coefficients. The size of the frequency dimension in the coefficient arrays is 1.

    outCFS_F = scatteringTransform(jtfn,x, ...
        FrequencyAverage="global");
    dictionaryValues_F = values(outCFS_F);
    cellfun(@size,dictionaryValues_F,UniformOutput=false)
    ans=5×1 cell array
        {[ 1 1 7 3 5]}
        {[ 5 1 7 3 5]}
        {[35 1 7 3 5]}
        {[35 1 7 3 5]}
        {[ 7 1 7 3 5]}
    
    

    Obtain the JTFS transform of the signal with TimeAverage and FrequencyAverage both set to "global".

    outCFS_TF = scatteringTransform(jtfn,x, ...
        TimeAverage="global", ...
        FrequencyAverage="global");
    dictionaryValues_TF = values(outCFS_TF);
    cellfun(@size,dictionaryValues_TF,UniformOutput=false)
    ans=5×1 cell array
        {[ 1 1 1 3 5]}
        {[ 5 1 1 3 5]}
        {[35 1 1 3 5]}
        {[35 1 1 3 5]}
        {[ 7 1 1 3 5]}
    
    

    Confirm the underlying data type of the coefficients is single precision.

    dictionaryValues = values(outCFS_TF);
    cellfun(@underlyingType,dictionaryValues,UniformOutput=false)
    ans = 5×1 cell
        {'single'}
        {'single'}
        {'single'}
        {'single'}
        {'single'}
    
    

    Gather the "SpinUp" coefficients from the GPU and compare them with the same coefficients in the JTFS transform of the original random signal computed on the CPU. The maximum absolute difference is on the order of single-precision round-off, so the coefficients agree.

    cfs = "SpinUp";
    cfsG = gather(outCFS_TF{cfs});
    
    outCFS_TF_ORIG = scatteringTransform(jtfn,sig, ...
        TimeAverage="global", ...
        FrequencyAverage="global");
    
    cfsO = outCFS_TF_ORIG{cfs};
    max(abs(cfsG(:)-cfsO(:)))
    ans = single
    
    7.4506e-08
    

    Input Arguments


    jtfn

    Joint time-frequency scattering network, specified as a timeFrequencyScattering object.

    x

    Input data, specified as a formatted or unformatted dlarray (Deep Learning Toolbox) object or a numeric array. If x is a formatted dlarray, it must be in "CBT" format. If x is an unformatted dlarray, it must be compatible with "CBT" format and you must set DataFormat.

    If x is 2-D, the scatteringTransform function assumes the first dimension is time and the columns of x are separate channels. If x is 3-D, the dimensions of x are time-by-channel-by-batch.

    • If x is a numeric vector or an unformatted dlarray vector, the number of samples in x must match the SignalLength property of jtfn.

    • If x is a numeric or unformatted dlarray matrix or 3-D array, the number of rows in x must match SignalLength.

    • If x is a formatted dlarray, the length of the time dimension must match SignalLength.

    Data Types: single | double
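    As a sketch of the 2-D case, the following code passes a two-channel random signal as a 1024-by-2 matrix; the length and channel count are arbitrary choices. Channels appear in the fourth dimension of each coefficient array, and because there is only one batch, the trailing batch dimension of size 1 does not appear in the reported size. The last lines pass a signal with the wrong number of rows, which the requirement above says is not allowed, so an error is expected; the sketch prints the message rather than assuming its text.

```matlab
nsam = 1024;
x2 = randn(nsam,2);                 % rows are time samples, columns are channels
jtfn = timeFrequencyScattering(SignalLength=nsam);
outCFS2 = scatteringTransform(jtfn,x2);

% path-by-frequency-by-time-by-channel
size(outCFS2{"SpinUp"})

% The number of rows must equal SignalLength; a shorter signal is expected to fail
try
    scatteringTransform(jtfn,randn(nsam/2,2));
catch err
    disp(err.message)
end
```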

    Name-Value Arguments

    Specify optional pairs of arguments as Name1=Value1,...,NameN=ValueN, where Name is the argument name and Value is the corresponding value. Name-value arguments must appear after other arguments, but the order of the pairs does not matter.

    Example: outCFS = scatteringTransform(jtfn,x,DataFormat="CBT",FrequencyAverage="global") specifies the format of the unformatted dlarray x as "CBT" and takes the mean along the frequency dimension for all JTFS coefficients.

    ExcludeCoefficients

    Coefficients to exclude from the JTFS transform, specified as a string vector or cell array of character vectors. You can specify these coefficients:

    • "S1FreqLowpass" — First-order time scattering coefficients filtered with the frequency lowpass filter

    • "S1SpinUpFreqLowpass" — First-order time scattering coefficients with the spin-up frequency wavelets

    • "SpinUp" — Second-order time scattering coefficients with spin-up wavelets

    • "SpinDown" — Second-order time scattering coefficients with spin-down wavelets

    • "U2JointLowpass" — Second-order time scattering coefficients filtered with joint lowpass filters

    Example: outCFS = scatteringTransform(jtfn,x,ExcludeCoefficients=["S1FreqLowpass" "U2JointLowpass"])
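    A minimal sketch of excluding coefficients, for example when a downstream model does not use them. The signal is an arbitrary 1024-sample random vector. This page does not state whether excluded keys are removed from outCFS or kept with empty values, so the sketch displays the dictionary instead of assuming either behavior.

```matlab
nsam = 1024;
x = randn(nsam,1);
jtfn = timeFrequencyScattering(SignalLength=nsam);

% Keep only the S1SpinUpFreqLowpass, SpinUp, and SpinDown coefficients
outCFS = scatteringTransform(jtfn,x, ...
    ExcludeCoefficients=["S1FreqLowpass" "U2JointLowpass"]);

% Display the dictionary to see which keys and values remain
outCFS
```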

    TimeAverage

    Time-averaging option, specified as one of these:

    • "local": scatteringTransform uses the time lowpass filter when obtaining the JTFS coefficients.

    • "global": scatteringTransform takes the mean along the time dimension for all JTFS coefficients.

    FrequencyAverage

    Frequency-averaging option, specified as one of these:

    • "local": scatteringTransform uses the frequency lowpass filter when obtaining the JTFS coefficients.

    • "global": scatteringTransform takes the mean along the frequency dimension for all JTFS coefficients.

    TimeOversamplingFactor

    Time oversampling factor, specified as a nonnegative integer. The factor specifies how much the coefficients are oversampled in time with respect to the critically downsampled values. The factor is on a base-2 logarithmic scale; for example, a factor of 1 oversamples by 2^1 = 2.

    If you increase the oversampling factor, the computational costs and memory requirements of the scattering transform also increase.

    Note

    The number of paths in the JTFS network does not depend on the time oversampling factor. This behavior differs from that of waveletScattering, where the value of the OversamplingFactor property affects the number of paths in the network.

    Data Types: single | double
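    The following sketch illustrates both statements above on the ECG signal used in the examples: moving from a factor of 0 to a factor of 1 should double the size of the time dimension, while the number of paths stays the same. Only factors 0 and 1 are used because those are the cases shown in the examples.

```matlab
load wecg
jtfn = timeFrequencyScattering(SignalLength=length(wecg));

factors = 0:1;
nPath = zeros(size(factors));
nTime = zeros(size(factors));
for i = 1:numel(factors)
    outCFS = scatteringTransform(jtfn,wecg,TimeOversamplingFactor=factors(i));
    sz = size(outCFS{"SpinUp"});
    nPath(i) = sz(1);   % path dimension
    nTime(i) = sz(3);   % time dimension
end
table(factors(:),nPath(:),nTime(:),VariableNames=["factor" "numPath" "numTime"])
```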

    FrequencyOversamplingFactor

    Frequency oversampling factor, specified as a nonnegative integer. The factor specifies how much the coefficients are oversampled in frequency with respect to the critically downsampled values. The factor is on a base-2 logarithmic scale; for example, a factor of 1 oversamples by 2^1 = 2.

    If you increase the oversampling factor, the computational costs and memory requirements of the scattering transform also increase.

    Note

    The number of paths in the JTFS network does not depend on the frequency oversampling factor. This behavior differs from that of waveletScattering, where the value of the OversamplingFactor property affects the number of paths in the network.

    Data Types: single | double

    DataFormat

    Data format of x, specified as a character vector or string scalar. This name-value argument is valid only if x is an unformatted dlarray. If x is not a dlarray, the function ignores DataFormat.

    Each character in this argument must be one of these labels:

    • "C" — Channel

    • "B" — Batch observations

    • "T" — Time

    DataFormat can be any permutation of "CBT".

    Data Types: char | string
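    A sketch of DataFormat with an unformatted dlarray (requires Deep Learning Toolbox). The same random time-by-channel-by-batch array is transformed once as a numeric array and once as an unformatted dlarray labeled "TCB", one permutation of "CBT". Because the dlarray result is an unformatted dlarray, extractdata converts it before the comparison. The sizes and the tolerance in the check are illustrative assumptions.

```matlab
nsam = 1024;
sig = randn(nsam,3,5);                  % time x channel x batch
jtfn = timeFrequencyScattering(SignalLength=nsam);

% Numeric input: 3-D arrays are interpreted as time-by-channel-by-batch
cfsNum = scatteringTransform(jtfn,sig);

% Unformatted dlarray input: DataFormat labels each dimension
cfsDl = scatteringTransform(jtfn,dlarray(sig),DataFormat="TCB");

% Compare the two results
a = cfsNum{"SpinUp"};
b = extractdata(cfsDl{"SpinUp"});
maxDiff = max(abs(a(:) - b(:)))
```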

    Output Arguments


    outCFS

    Joint time-frequency scattering transform, returned as a dictionary object with these keys:

    • "S1FreqLowpass" — First-order time scattering coefficients filtered with the frequency lowpass filter

    • "S1SpinUpFreqLowpass" — First-order time scattering coefficients with the spin-up frequency wavelets

    • "SpinUp" — Second-order time scattering coefficients with spin-up wavelets

    • "SpinDown" — Second-order time scattering coefficients with spin-down wavelets

    • "U2JointLowpass" — Second-order time scattering coefficients filtered with joint lowpass filters

    For more information, see Joint Time-Frequency Scattering Coefficients.

    All dictionary values are in path-by-frequency-by-time-by-channel-by-batch format.

    If x is a formatted or unformatted dlarray, every dictionary value is an unformatted dlarray.
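    Because every dictionary value shares the same dimension order, one possible way to build a feature matrix for a classifier is to move the batch dimension to the front and collapse the remaining dimensions. This is a hypothetical post-processing sketch, not part of scatteringTransform; it assumes a 3-D numeric input with five observations so that each value is 5-D.

```matlab
nsam = 1024;
sig = randn(nsam,3,5);                  % time x channel x batch
jtfn = timeFrequencyScattering(SignalLength=nsam);
outCFS = scatteringTransform(jtfn,sig);

k = keys(outCFS);
nbatch = size(sig,3);
feats = cell(1,numel(k));
for i = 1:numel(k)
    c = outCFS{k(i)};
    % Move batch (dimension 5) to the front, then flatten the other dimensions
    c = permute(c,[5 1 2 3 4]);
    feats{i} = reshape(c,nbatch,[]);
end

% One row per observation, one column per coefficient
features = [feats{:}];
size(features)
```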

    outMETA

    Metadata for each coefficient key in outCFS, returned as a cell array of tables. All tables have these variables:

    • type — Coefficient key.

    • path — Two-column variable indicating the coefficient path. The first column is the index of the frequency wavelet, and the second column is the index of the second-order time wavelet. A value of –1 indicates the lowpass filter.

    • spin — Wavelet spin. A value of 1 indicates a spin-up wavelet, and –1 indicates a spin-down wavelet. A value of 0 indicates that scatteringTransform did not use a spin-up or spin-down wavelet to compute those coefficients.

    • log2stride — Two-column variable indicating how much scatteringTransform downsamples in frequency (first column) and time (second column) after applying the lowpass filters. Values are on a base-2 logarithmic scale.

    • log2dsfactor — Downsampling factors in frequency and second-order time. Values are on a base-2 logarithmic scale.

      • If type is "SpinUp" or "SpinDown", then log2dsfactor is a two-column variable indicating the downsampling factors in frequency (first column) and second-order time (second column).

      • If type is "S1SpinUpFreqLowpass", then log2dsfactor is a single-column variable indicating the downsampling factor in frequency.

      • If type is "U2JointLowpass", then log2dsfactor is a single-column variable indicating the downsampling factor in second-order time.

      • This table variable is not applicable when type is "S1FreqLowpass".
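    Because log2stride is on a base-2 logarithmic scale, 2.^log2stride gives the actual strides. In the ECG example, the "SpinDown" time stride is 2^8 = 256, and 2048/256 = 8 matches the size of the time dimension. The following sketch performs that check; it relies on the signal length being a power of two, and for other lengths (such as the 1000-sample example) the ratio need not be an integer.

```matlab
load wecg
len = length(wecg);
jtfn = timeFrequencyScattering(SignalLength=len);
[outCFS,outMETA] = scatteringTransform(jtfn,wecg);

meta = outMETA{4};                 % "SpinDown" metadata
stride = 2.^meta.log2stride;       % [frequency time] strides for each path
nTime = size(outCFS{"SpinDown"},3);

% Signal length divided by the time stride, next to the actual time size
[len/stride(1,2) nTime]
```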

    More About


    Joint Time-Frequency Scattering Coefficients

    The joint time-frequency scattering (JTFS) transform extracts time-frequency features from a signal that are invariant to shifts and deformations in time and frequency. To compute the JTFS transform, first convolve the signal in time with wavelets and apply a pointwise modulus nonlinearity. Then filter the result along frequency with frequential wavelets [1][2].

    Let:

    • $x$ denote the signal.

    • $\psi_t^{(1)}$ and $\psi_t^{(2)}$ denote the time wavelets in the first- and second-order filter banks, respectively.

    • $\psi_{f,s}$ denote the frequential wavelets of spin $s$. If $s = 1$, these are the spin-up wavelets. If $s = -1$, these are the spin-down wavelets.

    • $\phi_t$ and $\phi_f$ denote the time and frequential lowpass filters, respectively.

    • $\ast$ denote convolution, along time for the time filters and along frequency for the frequential filters.

    Then the JTFS coefficients are defined as:

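    • "S1FreqLowpass": $\left|x \ast \psi_t^{(1)}\right| \ast \phi_t \ast \phi_f$

    • "S1SpinUpFreqLowpass": $\left|\left|x \ast \psi_t^{(1)}\right| \ast \phi_t \ast \psi_{f,s}\right| \ast \phi_f$ for $s = 1$

    • "SpinUp": $\left|\left|x \ast \psi_t^{(1)}\right| \ast \psi_t^{(2)} \ast \psi_{f,s}\right| \ast \phi_t \ast \phi_f$ for $s = 1$

    • "SpinDown": $\left|\left|x \ast \psi_t^{(1)}\right| \ast \psi_t^{(2)} \ast \psi_{f,s}\right| \ast \phi_t \ast \phi_f$ for $s = -1$

    • "U2JointLowpass": $\left|\left|x \ast \psi_t^{(1)}\right| \ast \psi_t^{(2)}\right| \ast \phi_t \ast \phi_f$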

    For more information, see Joint Time-Frequency Scattering.

    References

    [1] Andén, Joakim, Vincent Lostanlen, and Stéphane Mallat. “Joint Time–Frequency Scattering.” IEEE Transactions on Signal Processing 67, no. 14 (July 15, 2019): 3704–18. https://doi.org/10.1109/TSP.2019.2918992

    [2] Lostanlen, Vincent, Christian El-Hajj, Mathias Rossignol, Grégoire Lafay, Joakim Andén, and Mathieu Lagrange. “Time–Frequency Scattering Accurately Models Auditory Similarities between Instrumental Playing Techniques.” EURASIP Journal on Audio, Speech, and Music Processing 2021, no. 1 (December 2021): 3. https://doi.org/10.1186/s13636-020-00187-z

    [3] Mallat, Stéphane. “Group Invariant Scattering.” Communications on Pure and Applied Mathematics 65, no. 10 (October 2012): 1331–98. https://doi.org/10.1002/cpa.21413


    Version History

    Introduced in R2024b