MHAttention

MHAttention(
   query, key, value, n_units = None, n_heads = 1, attention_fn = tf.nn.softmax,
   causality = False, attention_dropout = 0.0, regularized = False, name = 'attention',
   share_state_with = None
)

Scaled Dot-Product Multi-Head Attention Layer

(Q, K, V): the input is encoded as a set of key-value pairs (K, V), both of dimension n (the input sequence length). In sequence-to-sequence models, the keys and values are the encoder hidden states; in the decoder, the previous output is compressed into a query Q of dimension m, and the next output is produced by attending over the keys and values with that query.
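Each head applies the standard scaled dot-product attention of the Transformer, Attention(Q, K, V) = softmax(QKᵀ / √d_k) V. Below is a minimal single-head TensorFlow sketch of that computation on plain tensors, independent of this layer's API:

```python
import tensorflow as tf

def scaled_dot_product_attention(q, k, v, causal=False):
    """Single-head scaled dot-product attention: softmax(QK^T / sqrt(d_k)) V."""
    # q: [batch, m, d_k], k: [batch, n, d_k], v: [batch, n, d_v]
    d_k = tf.cast(tf.shape(k)[-1], tf.float32)
    scores = tf.matmul(q, k, transpose_b=True) / tf.sqrt(d_k)   # [batch, m, n]
    if causal:
        # lower-triangular mask: position i attends only to positions <= i
        m, n = tf.shape(scores)[-2], tf.shape(scores)[-1]
        mask = tf.linalg.band_part(tf.ones([m, n]), -1, 0)
        scores += (1.0 - mask) * -1e9
    weights = tf.nn.softmax(scores, axis=-1)    # attention distribution over keys
    return tf.matmul(weights, v)                # [batch, m, d_v]
```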

Args

  • query : input layer providing the queries Q
  • key : input layer providing the keys K
  • value : input layer providing the values V
  • n_units : output number of units; each attention head has n_units // n_heads units
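A construction sketch, assuming the layer is exposed through the tensorx package and that tx.Input layers can feed the query, key, and value slots; the names and shapes below are illustrative assumptions, not taken from this page:

```python
import tensorx as tx

# illustrative inputs: 128-dimensional representations (assumed Input usage)
query = tx.Input(n_units=128)
key = tx.Input(n_units=128)
value = tx.Input(n_units=128)

attention = MHAttention(
    query, key, value,
    n_units=128,          # total output units, split across heads
    n_heads=8,            # each head gets 128 // 8 = 16 units
    causality=False,
    attention_dropout=0.1
)
```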

Methods:

.compute_shape

.compute_shape()

.init_state

.init_state()

.compute

.compute(
   *input_tensors
)

.reuse_with

.reuse_with(
   query, key, value, regularized = None, causality = None, name = None
)