Layer(inputs, n_units, shape=None, dtype=None, name='layer', **kwargs)

Layer Base Class

Passing Attributes

All keyword arguments passed to this constructor are set as instance attributes, so a typical implementing class might look like this:

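import tensorx as tx
from tensorx.layers import Layer

class CustomLayer(Layer):
    def __init__(self, layer, n_units, param=1):
        # declared here so linters know the attribute exists
        self.param = param
        # ...
        super().__init__(inputs=layer, n_units=n_units, param=param)

x = tx.Input(10)  # "in" is a reserved word in Python
y = CustomLayer(x, 4, param=2)
assert y.param == 2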
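Conceptually, the attribute passing described above amounts to the base constructor calling setattr once for each extra keyword argument. The following is a minimal pure-Python sketch of that behaviour. `SketchLayer` and `CustomSketch` are hypothetical stand-ins, not the actual tensorx implementation.

```python
class SketchLayer:
    """Hypothetical stand-in for the Layer base class (not the real tensorx code)."""

    def __init__(self, inputs=None, n_units=None, name="layer", **kwargs):
        self.inputs = inputs
        self.n_units = n_units
        self.name = name
        # every extra keyword argument becomes an instance attribute
        for key, value in kwargs.items():
            setattr(self, key, value)


class CustomSketch(SketchLayer):
    def __init__(self, layer, n_units, param=1):
        # declared here so linters know the attribute exists
        self.param = param
        # the base constructor sets it again from kwargs (same value)
        super().__init__(inputs=layer, n_units=n_units, param=param)


y = CustomSketch(None, 4, param=2)
print(y.param)  # → 2
```

Assigning `self.param` before calling `super().__init__` is redundant at runtime. It only makes the attribute visible to static analysis tools.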


Attributes:

  • inputs (Sequence[Layer]) : a list of input nodes for the current layer
  • n_units (int) : the number of units for the current layer (last dimension)
  • name (str) : name used for the layer scope
  • config (LayerConfig) : a layer configuration holding the arguments used to create the current layer instance
  • scoped_name (str) : the full scoped name of the layer


Args:

  • inputs (Sequence[Layer]) : a single layer, a list of input layers, or None if no inputs are required
  • n_units (int) : the number of units for the current layer, i.e. the last dimension (the number of columns when batch_size is not None)
  • dtype (DType) : dtype of the current layer output
  • shape (TensorShape) : output shape; if not None, it overrides the shape given by compute_shape
  • name (str) : layer name (used to name the placeholder)
  • kwargs (Any) : other keyword arguments to be set as instance attributes
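The inputs argument accepts a single layer, a list of layers, or None. Below is a sketch of how such an argument can be normalized to a list. `normalize_inputs` is a hypothetical helper written for illustration, not part of the documented API.

```python
from typing import Any, List, Optional, Sequence, Union


def normalize_inputs(inputs: Optional[Union[Any, Sequence[Any]]]) -> List[Any]:
    """Turn None, a single layer, or a list/tuple of layers into a list."""
    if inputs is None:
        return []  # no inputs required
    if isinstance(inputs, (list, tuple)):
        return list(inputs)  # already a sequence of layers
    return [inputs]  # a single layer


print(normalize_inputs(None))        # → []
print(normalize_inputs("a"))         # → ['a']
print(normalize_inputs(("a", "b")))  # → ['a', 'b']
```

Only lists and tuples are treated as sequences. This way, a layer object that happens to be iterable is not accidentally split into pieces.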





compute_shape()

Called before init_state().

Returns:

  • shape (tf.TensorShape) : best guess for the output shape of the layer
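The sketch below shows the construction order: an explicit shape takes precedence over compute_shape, and init_state runs only once the shape is known. It is plain Python and uses tuples in place of tf.TensorShape. `ShapeSketch` is hypothetical. It assumes that "overrides" means compute_shape is skipped when a shape is given, which may differ from the real tensorx code.

```python
class ShapeSketch:
    """Hypothetical sketch of the construction order (not the real tensorx code)."""

    def __init__(self, n_units, shape=None):
        self.n_units = n_units
        # an explicit shape argument takes precedence over compute_shape()
        self.shape = shape if shape is not None else self.compute_shape()
        # init_state runs only after the shape has been decided
        self.state = self.init_state()

    def compute_shape(self):
        # best guess: unknown batch dimension, n_units as the last dimension
        return (None, self.n_units)

    def init_state(self):
        # record the shape that was available when the state was created
        return {"shape_seen": self.shape}


print(ShapeSketch(4).shape)               # → (None, 4)
print(ShapeSketch(4, shape=(2, 4)).shape)  # → (2, 4)
```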




init_state(): meant to be overridden in subclasses.

Creates an empty LayerState object.

Overriding init_state()

Classes implementing Layer should override this method, for example:

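import tensorflow as tf
from tensorx.layers import Layer

class MyLayer(Layer):
    def init_state(self):
        state = super().init_state()
        # or: state = LayerState()
        # var1 and var2 are illustrative; create whatever your layer needs
        state.var1 = tf.Variable(tf.zeros([self.n_units]))
        state.var2 = tf.Variable(tf.ones([self.n_units]))
        return state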

Layer takes this state object and adds var1 and var2 to its instance attributes.


Returns:

  • state (LayerState) : current layer state object
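The init_state contract described above can be mimicked in plain Python. The base method returns an empty state, a subclass extends it through super(), and the layer copies every state attribute onto itself. `StateSketch`, `BaseSketch` and `VarSketch` are hypothetical stand-ins for LayerState and Layer, not tensorx classes. The plain values replace real variables.

```python
class StateSketch:
    """Hypothetical stand-in for LayerState: a plain attribute container."""


class BaseSketch:
    def __init__(self):
        state = self.init_state()
        # expose every attribute of the returned state on the layer itself
        for key, value in vars(state).items():
            setattr(self, key, value)
        self.layer_state = state

    def init_state(self):
        # the base version only creates an empty state object
        return StateSketch()


class VarSketch(BaseSketch):
    def init_state(self):
        state = super().init_state()
        state.var1 = [0.0, 0.0]  # placeholder values standing in for variables
        state.var2 = 1.0
        return state


layer = VarSketch()
print(layer.var1, layer.var2)  # → [0.0, 0.0] 1.0
```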






(name='layer_function', compile=False)

Returns the graph as a callable: either a Python function or a TensorFlow compiled graph.

This returns the entire graph that terminates on this layer as a function. If you want a function for this layer alone, use tf.function(layer.compute).


Args:

  • name (str) : name of the returned function
  • compile (bool) : if True, returns a TensorFlow compiled graph as a callable; otherwise, returns a Python function


Returns:

  • fn (Callable) : either a TensorFlow static graph or a Python function
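The difference between the whole-graph function and a single layer's compute can be sketched in plain Python. In the sketch, as_function evaluates every node leading up to the terminal layer, while compute only applies the layer's own operation to values you supply. `InputSketch`, `OpSketch` and `evaluate` are hypothetical names, and the compile flag is left out. In tensorx, compile=True would instead yield a TensorFlow compiled graph.

```python
from typing import Any, Callable


class InputSketch:
    """Hypothetical input node: holds a value and has no inputs."""

    def __init__(self, value: Any):
        self.value = value
        self.inputs = []

    def compute(self) -> Any:
        return self.value


class OpSketch:
    """Hypothetical layer node applying fn to the values of its inputs."""

    def __init__(self, fn: Callable, *inputs):
        self.fn = fn
        self.inputs = list(inputs)

    def compute(self, *values: Any) -> Any:
        # this layer's own operation only; input values must be supplied
        return self.fn(*values)

    def as_function(self) -> Callable[[], Any]:
        # close over the whole graph that terminates on this node
        def graph_fn() -> Any:
            return evaluate(self)

        return graph_fn


def evaluate(node) -> Any:
    """Recursively evaluate a node after evaluating all of its inputs."""
    values = [evaluate(parent) for parent in node.inputs]
    return node.compute(*values)


x = InputSketch(3)
h = OpSketch(lambda v: v * 2, x)
y = OpSketch(lambda v: v + 1, h)
fn = y.as_function()
print(fn())           # → 7, the whole graph: 3 * 2 + 1
print(y.compute(10))  # → 11, only y's own operation
```

This naive recursion re-evaluates shared nodes once per path. A real graph function would evaluate each node once, for example by caching results or by walking the graph in topological order.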



   *layers, **kwargs