Layer
Layer(
inputs, n_units, shape = None, dtype = None, name = 'layer', **kwargs
)
Layer Base Class
Passing Attributes
All keyword arguments passed to this constructor are set as instance attributes, so a common pattern for an implementing class is:
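import tensorx as tx


class CustomLayer(tx.Layer):
    def __init__(self, layer, n_units, param=1):
        # declare the attribute here so linters know it exists
        self.param = param
        # ...
        # extra keyword args (here: param) become instance attributes
        super().__init__(inputs=layer, n_units=n_units, param=param)


x = tx.Input(n_units=10)  # `in` is a reserved word in Python, so use another name
y = CustomLayer(x, 4, param=2)
assert y.param == 2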
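The attribute-passing behaviour can be modelled in plain Python. The sketch below uses a hypothetical `MiniLayer`, not tensorx's implementation. It assumes only what the paragraph above states: every extra keyword argument becomes an instance attribute.

```python
from typing import Any, Optional, Sequence


class MiniLayer:
    """Hypothetical stand-in for Layer, showing only kwargs -> attributes."""

    def __init__(self, inputs: Optional[Sequence["MiniLayer"]], n_units: int,
                 name: str = "layer", **kwargs: Any) -> None:
        self.inputs = list(inputs) if inputs is not None else []
        self.n_units = n_units
        self.name = name
        # every remaining keyword argument becomes an instance attribute
        for key, value in kwargs.items():
            setattr(self, key, value)


class CustomMiniLayer(MiniLayer):
    def __init__(self, layer: MiniLayer, n_units: int, param: int = 1) -> None:
        # declared up front so linters and IDEs see the attribute
        self.param = param
        super().__init__(inputs=[layer], n_units=n_units, param=param)


x = MiniLayer(None, 10, name="x")
y = CustomMiniLayer(x, 4, param=2)
print(y.param, y.n_units, y.inputs[0].name)  # 2 4 x
```

Assigning `self.param` before calling `super().__init__` is redundant at runtime, because the base class sets it again. It is kept so that static tools can see the attribute declared in the subclass.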
Attributes
- inputs (Sequence[Layer]): a list of input nodes for the current layer
- n_units: the number of units for the current layer (last dim)
- name (str): name to be used for the layer scope
- config (LayerConfig): a layer configuration with the arguments used in the current layer instance
- scoped_name (str): layer full scope name
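The `config` attribute records the arguments an instance was built with. One plausible use is rebuilding a layer from its config, sketched below under assumptions. `ToyConfig` and `ToyLayer` are hypothetical and do not reflect the actual `LayerConfig` API.

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class ToyConfig:
    """Hypothetical config: the layer class plus its constructor arguments."""
    layer_cls: type
    arguments: Dict[str, Any] = field(default_factory=dict)

    def __call__(self, **overrides: Any) -> Any:
        # build a fresh instance, optionally overriding some arguments
        return self.layer_cls(**{**self.arguments, **overrides})


class ToyLayer:
    def __init__(self, inputs=None, n_units=1, name="layer", **kwargs):
        self.inputs = inputs
        self.n_units = n_units
        self.name = name
        for key, value in kwargs.items():
            setattr(self, key, value)
        # keep the exact arguments used to create this instance
        self.config = ToyConfig(type(self), dict(inputs=inputs, n_units=n_units,
                                                 name=name, **kwargs))


a = ToyLayer(n_units=8, name="dense", activation="relu")
b = a.config(name="dense_copy")
print(b.n_units, b.name, b.activation)  # 8 dense_copy relu
```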
Args
- inputs (Sequence[Layer]): a single layer, a list of input layers, or None if no inputs are required
- n_units (int): dimension of the input vector (dimension of columns in case batch_size != None)
- dtype (DType): dtype for the current layer output
- shape (TensorShape): output shape; if not None, it overrides compute_shape
- name (str): layer name (used to name the placeholder)
- kwargs (Any): other keyword args to be set as instance attributes
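The `inputs` argument accepts three forms. A constructor could normalise them into one list along the lines below. `Node` and `as_input_list` are hypothetical names, and the `isinstance` check stands in for whatever type test `Layer` actually performs.

```python
from typing import List, Sequence, Union


class Node:
    """Hypothetical minimal layer used only to illustrate input handling."""

    def __init__(self, name: str) -> None:
        self.name = name


def as_input_list(inputs: Union[None, Node, Sequence[Node]]) -> List[Node]:
    """Normalise the accepted forms of `inputs` into a list of layers."""
    if inputs is None:
        return []            # a layer that requires no inputs
    if isinstance(inputs, Node):
        return [inputs]      # a single layer
    return list(inputs)      # any sequence of layers


a, b = Node("a"), Node("b")
print([n.name for n in as_input_list(a)])       # ['a']
print([n.name for n in as_input_list([a, b])])  # ['a', 'b']
print(as_input_list(None))                      # []
```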
Methods:
.compute_shape
.compute_shape()
Called before init_state.
Returns
- shape (tf.TensorShape): best guess for the output shape of the layer
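For a dense-style layer, a reasonable best guess keeps the input's leading dimensions and sets the last one to `n_units`. A non-None `shape` argument replaces that guess. This is a sketch with plain tuples standing in for `tf.TensorShape`. The "replace the last dimension" rule is an assumption that fits layers such as a linear projection, and other layers would override it.

```python
from typing import Optional, Tuple

Shape = Tuple[Optional[int], ...]  # None marks an unknown dimension


def compute_shape(input_shape: Shape, n_units: int) -> Shape:
    """Best guess: keep the leading (batch) dims, replace the last with n_units."""
    return tuple(input_shape[:-1]) + (n_units,)


def output_shape(input_shape: Shape, n_units: int,
                 shape: Optional[Shape] = None) -> Shape:
    """An explicit `shape` argument, when given, overrides compute_shape."""
    return shape if shape is not None else compute_shape(input_shape, n_units)


print(output_shape((None, 10), 4))           # (None, 4)
print(output_shape((None, 10), 4, (32, 4)))  # (32, 4)
```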
.init_state
.init_state()
Meant to be overridden in subclasses. The base implementation creates an empty LayerState object.
Overriding init_state()
Classes implementing Layer should override this method:
def init_state(self):
    state = super().init_state()
    # or: state = LayerState()
    # var1 and var2 stand for variables created by the subclass
    state.var1 = var1
    state.var2 = var2
    return state
Layer will take this state object and add var1 and var2 to its attributes.
Returns
- state (LayerState): current layer state object
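The override pattern above can be mimicked in plain Python to show how state entries end up as layer attributes. `ToyState`, `ToyLayer`, `ToyLinear` and the `layer_state` attribute are hypothetical. Plain lists stand in for trainable variables.

```python
from types import SimpleNamespace


class ToyState(SimpleNamespace):
    """Hypothetical stand-in for LayerState: a bag of named variables."""


class ToyLayer:
    def __init__(self, n_units: int) -> None:
        self.n_units = n_units
        self.layer_state = self.init_state()
        # expose every state entry as a layer attribute (state.bias -> layer.bias)
        for key, value in vars(self.layer_state).items():
            setattr(self, key, value)

    def init_state(self) -> ToyState:
        return ToyState()  # base class: empty state


class ToyLinear(ToyLayer):
    def init_state(self) -> ToyState:
        state = super().init_state()
        # plain Python values stand in for trainable variables
        state.weights = [0.0] * self.n_units
        state.bias = 0.0
        return state


lin = ToyLinear(3)
print(lin.weights, lin.bias)  # [0.0, 0.0, 0.0] 0.0
```

The attributes reference the same objects held in the state, so a layer and its state never disagree about which variables exist.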
.compute
.compute(
*args
)
.as_function
.as_function(
name = 'layer_function', compile = False
)
Returns the graph that ends on this layer as a callable: either a Python function or a TensorFlow-compiled graph.
Note
This returns the entire graph that terminates on this layer as a function. If you want the function for this layer alone, use tf.function(layer.compute).
Args
- name (str): name of the returned function
- compile (bool): if True, returns a TensorFlow compiled graph as a callable; otherwise, returns a Python function
Returns
- fn (Callable): either a TensorFlow static graph or a Python callable function.
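The toy graph below contrasts as_function with compute. It is a hypothetical example in plain Python, and `ToyNode` is not a tensorx class. as_function walks every node reachable from this layer and evaluates them in dependency order, while compute applies only this node's own operation. The compile flag is left out. With compile=True the real method returns a TensorFlow-compiled graph, which this sketch does not attempt, so it runs without TensorFlow.

```python
from typing import Callable, Dict, List, Optional, Sequence


class ToyNode:
    """Hypothetical graph node: a placeholder (no inputs) or an op on other nodes."""

    def __init__(self, inputs: Sequence["ToyNode"] = (), op: Optional[Callable] = None,
                 name: str = "node") -> None:
        self.inputs = list(inputs)
        self.op = op
        self.name = name

    def compute(self, *args):
        # this node alone: apply its op to input values computed elsewhere
        return self.op(*args)

    def as_function(self, name: str = "layer_function") -> Callable:
        """Return a plain Python callable for the whole graph ending at this node."""
        order: List["ToyNode"] = []
        seen = set()

        def visit(node: "ToyNode") -> None:
            # depth-first post-order yields a topological order of the graph
            if id(node) in seen:
                return
            seen.add(id(node))
            for parent in node.inputs:
                visit(parent)
            order.append(node)

        visit(self)
        # nodes without inputs are placeholders, fed by the function's
        # positional arguments in discovery order (an assumption of this sketch)
        placeholders = [n for n in order if not n.inputs]

        def fn(*values):
            results: Dict[int, object] = dict(zip((id(p) for p in placeholders), values))
            for node in order:
                if node.inputs:
                    results[id(node)] = node.compute(*(results[id(p)] for p in node.inputs))
            return results[id(self)]

        fn.__name__ = name
        return fn


x = ToyNode(name="x")
y = ToyNode(name="y")
s = ToyNode([x, y], op=lambda a, b: a + b, name="sum")
out = ToyNode([s], op=lambda v: v * 2, name="double")

f = out.as_function()
print(f(3, 4))         # 14: the whole graph, (3 + 4) * 2
print(out.compute(5))  # 10: this node alone
```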
.reuse_with
.reuse_with(
*layers, **kwargs
)