MHAttention
MHAttention(
query, key, value, n_units = None, n_heads = 1, attention_fn = tf.nn.softmax,
causality = False, attention_dropout = 0.0, regularized = False, name = 'attention',
share_state_with = None
)
Scaled Dot Product MultiHead Attention Layer
(Q, K, V): encodes a representation of the input as a set of key-value pairs (K, V), both of dimension n (the input sequence length); in the context of sequence-to-sequence models, both keys and values are the encoder hidden states. In the decoder, the previous output is compressed into a query Q of dimension m.
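The scaled dot-product attention computed per head can be summarised by the following minimal sketch in plain TensorFlow; it illustrates the operation only and is not the layer's actual implementation (the function name `scaled_dot_product_attention` is introduced here just for the example).

```python
import tensorflow as tf

def scaled_dot_product_attention(query, key, value, attention_fn=tf.nn.softmax):
    # query: [batch, m, d], key and value: [batch, n, d]
    d = tf.cast(tf.shape(key)[-1], tf.float32)
    # similarity between every query and every key, scaled by sqrt(d): [batch, m, n]
    scores = tf.matmul(query, key, transpose_b=True) / tf.sqrt(d)
    # attention distribution over the keys (tf.nn.softmax is the layer's default attention_fn)
    weights = attention_fn(scores)
    # weighted sum of the values: [batch, m, d]
    return tf.matmul(weights, value)
```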
Args
- query : the queries (Q)
- key : the keys (K)
- value : the values (V)
- n_units : output number of units; each attention head has n_units // n_heads units (see the sketch after this list)
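To make the n_units // n_heads split concrete, the sketch below shows one way a projected tensor with n_units features can be divided among the heads; the shapes and variable names are assumptions for illustration, not the layer's internal code.

```python
import tensorflow as tf

n_units, n_heads = 512, 8
head_units = n_units // n_heads              # units per attention head (64 here)

x = tf.random.normal([2, 10, n_units])       # [batch, seq_len, n_units]
# split the feature dimension into n_heads groups of head_units each
heads = tf.reshape(x, [2, 10, n_heads, head_units])
# [batch, n_heads, seq_len, head_units]: each head attends independently
heads = tf.transpose(heads, [0, 2, 1, 3])
```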
Methods:
.compute_shape()
.init_state()
.compute(*input_tensors)
.reuse_with(query, key, value, regularized = None, causality = None, name = None)