How to conditionally define a learnable property?
Hello,
Is it possible, and if so how, to define a learnable property based on a condition checked in the layer constructor?
Currently I define all learnable properties in the properties(Learnable) block of the layer definition. I would like some properties to exist only if a certain condition, checked in the constructor, is met. For example:
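properties(Learnable)
    % learnable properties that always exist
end
methods
    function self = constructor(cond)
        if cond
            % Desired behaviour (not valid MATLAB): add a learnable
            % property only when cond is true
            p = addprop(self,'condprop');
            p.Learnable = true;
        end
    end
end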
I tried using dynamic properties (dynamicprops), but this doesn't work because dynamicprops inherits from the handle class, which the layer's other superclasses do not.
Thx
Answers (3)
Matt J on 23 Mar 2023 (edited 23 Mar 2023)
If you really must have a layer whose properties depend on a conditional flag, it would probably be better to replace the layer in the network with a layer of a different class, which can also be done conditionally:
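% Layers: your layer array; something: text identifying the layer to swap
loc = contains({Layers.Description}, something);
if cond
    Layers(loc) = newlayer;   % newlayer: an instance of the alternative layer class
end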
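A minimal sketch of the same idea on a layer graph, assuming R2021b or later with Deep Learning Toolbox (for functionLayer and layerNormalizationLayer). The flag useNorm and the layer names 'slot' and 'ln' are hypothetical, and the built-in layers only stand in for your own custom layer classes: a pass-through layer with no learnables is swapped, when the flag is set, for a layer that has learnable parameters. replaceLayer removes the named layer and wires the new one into the same connections.

```matlab
% Hypothetical flag: true means "use the layer class with extra learnables"
useNorm = true;

layers = [
    featureInputLayer(4, 'Name', 'in')
    fullyConnectedLayer(8, 'Name', 'fc1')
    functionLayer(@(X) X, 'Name', 'slot')   % pass-through, no learnable parameters
    fullyConnectedLayer(1, 'Name', 'fc2')];
lgraph = layerGraph(layers);

if useNorm
    % layerNormalizationLayer has learnable Offset and Scale parameters
    lgraph = replaceLayer(lgraph, 'slot', layerNormalizationLayer('Name', 'ln'));
end

names = {lgraph.Layers.Name};
```

In newer releases the same swap can be done on a dlnetwork, which also supports replaceLayer.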
Matt J on 23 Mar 2023
An indirect solution would be to add a property Wknown that lets you supply an overriding, known value for a particular learnable property W:
classdef myLayer < nnet.layer.Layer % ...
    properties
        Wknown = []   % optional override for the learnable property W
    end
    properties(Learnable)
        W
    end
    methods
        function layer = myLayer(wknown)
            if nargin
                layer.Wknown = wknown;
            end
        end
    end
If no non-empty value for Wknown is passed to the constructor, your forward() and backward() methods treat the learnable parameter W in the normal way. If Wknown is provided, write forward() to use Wknown instead of W when computing the output, and write backward() to return an all-zero dLdW, so W is never updated.
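    methods
        function [Z,state,memory] = forward(layer,X)
            useW = isempty(layer.Wknown);   % true when no override was given
            if useW
                W = layer.W;         % learnable W
                Z=...
            else
                W = layer.Wknown;    % fixed override
                Z=...
            end
        end
        function [dLdX,dLdW,dLdSin] = backward(layer,X,Z,dLdZ,dLdSout,memory)
            useW = isempty(layer.Wknown);
            if useW
                W = layer.W;
                dLdW=...
            else
                W = layer.Wknown;
                % zero gradient with the same size as W, so W is never updated
                dLdW = zeros(size(layer.W), 'like', layer.W);
            end
            % dLdX and dLdSin are computed using W in both cases (not shown)
        end
    end
end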
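A related, built-in way to keep a learnable parameter from changing is to set its learn rate factor to 0 with setLearnRateFactor, which also accepts learnable property names of custom layers (for example 'W' for myLayer above). The sketch below is illustrative only: fullyConnectedLayer stands in for a custom layer, useKnownWeights is a hypothetical flag, and the factor is applied by the built-in training functions such as trainNetwork.

```matlab
% Hypothetical flag: true means "use fixed, known weights"
useKnownWeights = true;

layer = fullyConnectedLayer(4, 'Name', 'fc');
if useKnownWeights
    layer.Weights = ones(4, 3);                       % fixed, known values
    layer = setLearnRateFactor(layer, 'Weights', 0);  % solver leaves Weights unchanged
end

factor = getLearnRateFactor(layer, 'Weights');
```

Unlike the Wknown override, W itself holds the fixed values here, so forward() needs no branching; the trade-off is that W always exists as a learnable property.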
Categories: Deep Learning Toolbox

