アテンションモデルに空白を入れたくないので、単に
:math:`\sum_{i=1}^n \alpha(\mathbf{q}, \mathbf{k}_i) \mathbf{v}_i`
を、実際の文の長さ :math:`l \leq n` に応じて
:math:`\sum_{i=1}^l \alpha(\mathbf{q}, \mathbf{k}_i) \mathbf{v}_i`
に制限すればよいのである。これは非常に一般的な問題なので、名前が付いている。それが
*マスク付きソフトマックス演算* である。
実装してみよう。実際の実装では、\ :math:`i > l` の :math:`\mathbf{v}_i`
の値をゼロにすることで、わずかにごまかしている。さらに、勾配や値への寄与を実質的に消すために、アテンション重みを
:math:`-10^{6}`
のような非常に大きな負の値に設定する。これは、線形代数のカーネルや演算子がGPU向けに強く最適化されており、条件分岐(if
then
else)を含むコードにするよりも、多少計算を無駄にしても速いからである。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
def masked_softmax(X, valid_lens): #@save
"""Perform softmax operation by masking elements on the last axis."""
# X: 3D tensor, valid_lens: 1D or 2D tensor
def _sequence_mask(X, valid_len, value=0):
maxlen = X.size(1)
mask = torch.arange((maxlen), dtype=torch.float32,
device=X.device)[None, :] < valid_len[:, None]
X[~mask] = value
return X
if valid_lens is None:
return nn.functional.softmax(X, dim=-1)
else:
shape = X.shape
if valid_lens.dim() == 1:
valid_lens = torch.repeat_interleave(valid_lens, shape[1])
else:
valid_lens = valid_lens.reshape(-1)
# On the last axis, replace masked elements with a very large negative
# value, whose exponentiation outputs 0
X = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
return nn.functional.softmax(X.reshape(shape), dim=-1)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
def masked_softmax(X, valid_lens): #@save
"""Perform softmax operation by masking elements on the last axis."""
# X: 3D tensor, valid_lens: 1D or 2D tensor
if valid_lens is None:
return npx.softmax(X)
else:
shape = X.shape
if valid_lens.ndim == 1:
valid_lens = valid_lens.repeat(shape[1])
else:
valid_lens = valid_lens.reshape(-1)
# On the last axis, replace masked elements with a very large negative
# value, whose exponentiation outputs 0
X = npx.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, True,
value=-1e6, axis=1)
return npx.softmax(X).reshape(shape)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
def masked_softmax(X, valid_lens): #@save
"""Perform softmax operation by masking elements on the last axis."""
# X: 3D tensor, valid_lens: 1D or 2D tensor
def _sequence_mask(X, valid_len, value=0):
maxlen = X.shape[1]
mask = jnp.arange((maxlen),
dtype=jnp.float32)[None, :] < valid_len[:, None]
return jnp.where(mask, X, value)
if valid_lens is None:
return nn.softmax(X, axis=-1)
else:
shape = X.shape
if valid_lens.ndim == 1:
valid_lens = jnp.repeat(valid_lens, shape[1])
else:
valid_lens = valid_lens.reshape(-1)
# On the last axis, replace masked elements with a very large negative
# value, whose exponentiation outputs 0
X = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
return nn.softmax(X.reshape(shape), axis=-1)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
def masked_softmax(X, valid_lens): #@save
"""Perform softmax operation by masking elements on the last axis."""
# X: 3D tensor, valid_lens: 1D or 2D tensor
def _sequence_mask(X, valid_len, value=0):
maxlen = X.shape[1]
mask = tf.range(start=0, limit=maxlen, dtype=tf.float32)[
None, :] < tf.cast(valid_len[:, None], dtype=tf.float32)
if len(X.shape) == 3:
return tf.where(tf.expand_dims(mask, axis=-1), X, value)
else:
return tf.where(mask, X, value)
if valid_lens is None:
return tf.nn.softmax(X, axis=-1)
else:
shape = X.shape
if len(valid_lens.shape) == 1:
valid_lens = tf.repeat(valid_lens, repeats=shape[1])
else:
valid_lens = tf.reshape(valid_lens, shape=-1)
# On the last axis, replace masked elements with a very large negative
# value, whose exponentiation outputs 0
X = _sequence_mask(tf.reshape(X, shape=(-1, shape[-1])), valid_lens,
value=-1e6)
return tf.nn.softmax(tf.reshape(X, shape=shape), axis=-1)
.. raw:: html
.. raw:: html
この関数がどのように動作するかを示すため、サイズが :math:`2 \times 4`
の2つの例からなるミニバッチを考え、それぞれの有効長が :math:`2` と
:math:`3`
であるとする。マスク付きソフトマックス演算の結果、各ベクトルの組について有効長を超える値はすべてゼロとしてマスクされる。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
tensor([[[0.4639, 0.5361, 0.0000, 0.0000],
[0.3481, 0.6519, 0.0000, 0.0000]],
[[0.4261, 0.1981, 0.3759, 0.0000],
[0.4635, 0.2937, 0.2428, 0.0000]]])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
masked_softmax(np.random.uniform(size=(2, 2, 4)), d2l.tensor([2, 3]))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
masked_softmax(jax.random.uniform(d2l.get_key(), (2, 2, 4)), jnp.array([2, 3]))
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
Array([[[0.54144967, 0.4585504 , 0. , 0. ],
[0.54905266, 0.45094734, 0. , 0. ]],
[[0.23869859, 0.52262574, 0.2386757 , 0. ],
[0.4002735 , 0.3469047 , 0.25282183, 0. ]]], dtype=float32)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
masked_softmax(tf.random.uniform(shape=(2, 2, 4)), tf.constant([2, 3]))
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
.. raw:: html
.. raw:: html
各例の2つのベクトルそれぞれに対して有効長をより細かく指定したい場合は、単に2次元の有効長テンソルを使う。すると次のようになる。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
masked_softmax(torch.rand(2, 2, 4), d2l.tensor([[1, 3], [2, 4]]))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
masked_softmax(np.random.uniform(size=(2, 2, 4)),
d2l.tensor([[1, 3], [2, 4]]))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
masked_softmax(jax.random.uniform(d2l.get_key(), (2, 2, 4)),
jnp.array([[1, 3], [2, 4]]))
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
Array([[[1. , 0. , 0. , 0. ],
[0.29248205, 0.37135315, 0.3361648 , 0. ]],
[[0.38551602, 0.61448395, 0. , 0. ],
[0.17129263, 0.21547154, 0.25191936, 0.36131644]]], dtype=float32)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
masked_softmax(tf.random.uniform((2, 2, 4)), tf.constant([[1, 3], [2, 4]]))
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
.. raw:: html
.. raw:: html
.. _subsec_batch_dot:
バッチ行列積
~~~~~~~~~~~~
もう1つよく使われる演算は、行列のバッチ同士を掛け合わせることである。これは、クエリ、キー、値のミニバッチを扱うときに便利である。より具体的には、次を仮定しよう。
.. math::
\mathbf{Q} = [\mathbf{Q}_1, \mathbf{Q}_2, \ldots, \mathbf{Q}_n] \in \mathbb{R}^{n \times a \times b}, \\
\mathbf{K} = [\mathbf{K}_1, \mathbf{K}_2, \ldots, \mathbf{K}_n] \in \mathbb{R}^{n \times b \times c}.
このとき、バッチ行列積(BMM)は要素ごとの積を計算する。
.. math:: \textrm{BMM}(\mathbf{Q}, \mathbf{K}) = [\mathbf{Q}_1 \mathbf{K}_1, \mathbf{Q}_2 \mathbf{K}_2, \ldots, \mathbf{Q}_n \mathbf{K}_n] \in \mathbb{R}^{n \times a \times c}.
:label: eq_batch-matrix-mul
深層学習フレームワークでこれを見てみよう。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
Q = d2l.ones((2, 3, 4))
K = d2l.ones((2, 4, 6))
d2l.check_shape(torch.bmm(Q, K), (2, 3, 6))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
Q = d2l.ones((2, 3, 4))
K = d2l.ones((2, 4, 6))
d2l.check_shape(npx.batch_dot(Q, K), (2, 3, 6))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
Q = d2l.ones((2, 3, 4))
K = d2l.ones((2, 4, 6))
d2l.check_shape(jax.lax.batch_matmul(Q, K), (2, 3, 6))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
Q = d2l.ones((2, 3, 4))
K = d2l.ones((2, 4, 6))
d2l.check_shape(tf.matmul(Q, K).numpy(), (2, 3, 6))
.. raw:: html
.. raw:: html
スケールド内積アテンション
--------------------------
:eq:`eq_dot_product_attention` で導入した内積アテンションに戻ろう。
一般に、クエリとキーの両方が同じベクトル長、たとえば :math:`d`
を持つことが必要であるが、これは :math:`\mathbf{q}^\top \mathbf{k}` を
:math:`\mathbf{q}^\top \mathbf{M} \mathbf{k}`
に置き換え、\ :math:`\mathbf{M}`
を両空間を変換するために適切に選んだ行列とすることで簡単に対処できる。ここでは、次元が一致していると仮定する。
実際には、効率のためにミニバッチを考えることが多く、たとえば :math:`n`
個のクエリと :math:`m`
個のキー・値ペアに対してアテンションを計算する。このとき、クエリとキーの長さは
:math:`d`\ 、値の長さは :math:`v` である。したがって、クエリ
:math:`\mathbf Q\in\mathbb R^{n\times d}`\ 、キー
:math:`\mathbf K\in\mathbb R^{m\times d}`\ 、値
:math:`\mathbf V\in\mathbb R^{m\times v}`
に対するスケールド内積アテンションは次のように書ける。
.. math:: \mathrm{softmax}\left(\frac{\mathbf Q \mathbf K^\top }{\sqrt{d}}\right) \mathbf V \in \mathbb{R}^{n\times v}.
:label: eq_softmax_QK_V
これをミニバッチに適用する際には、:eq:`eq_batch-matrix-mul`
で導入したバッチ行列積が必要になることに注意されたい。以下のスケールド内積アテンションの実装では、モデル正則化のためにドロップアウトを使う。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class DotProductAttention(nn.Module): #@save
"""Scaled dot product attention."""
def __init__(self, dropout):
super().__init__()
self.dropout = nn.Dropout(dropout)
# Shape of queries: (batch_size, no. of queries, d)
# Shape of keys: (batch_size, no. of key-value pairs, d)
# Shape of values: (batch_size, no. of key-value pairs, value dimension)
# Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
def forward(self, queries, keys, values, valid_lens=None):
d = queries.shape[-1]
# Swap the last two dimensions of keys with keys.transpose(1, 2)
scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
self.attention_weights = masked_softmax(scores, valid_lens)
return torch.bmm(self.dropout(self.attention_weights), values)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class DotProductAttention(nn.Block): #@save
"""Scaled dot product attention."""
def __init__(self, dropout):
super().__init__()
self.dropout = nn.Dropout(dropout)
# Shape of queries: (batch_size, no. of queries, d)
# Shape of keys: (batch_size, no. of key-value pairs, d)
# Shape of values: (batch_size, no. of key-value pairs, value dimension)
# Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
def forward(self, queries, keys, values, valid_lens=None):
d = queries.shape[-1]
# Set transpose_b=True to swap the last two dimensions of keys
scores = npx.batch_dot(queries, keys, transpose_b=True) / math.sqrt(d)
self.attention_weights = masked_softmax(scores, valid_lens)
return npx.batch_dot(self.dropout(self.attention_weights), values)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class DotProductAttention(nn.Module): #@save
"""Scaled dot product attention."""
dropout: float
# Shape of queries: (batch_size, no. of queries, d)
# Shape of keys: (batch_size, no. of key-value pairs, d)
# Shape of values: (batch_size, no. of key-value pairs, value dimension)
# Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
@nn.compact
def __call__(self, queries, keys, values, valid_lens=None,
training=False):
d = queries.shape[-1]
# Swap the last two dimensions of keys with keys.swapaxes(1, 2)
scores = queries@(keys.swapaxes(1, 2)) / math.sqrt(d)
attention_weights = masked_softmax(scores, valid_lens)
dropout_layer = nn.Dropout(self.dropout, deterministic=not training)
return dropout_layer(attention_weights)@values, attention_weights
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class DotProductAttention(tf.keras.layers.Layer): #@save
"""Scaled dot product attention."""
def __init__(self, dropout):
super().__init__()
self.dropout = tf.keras.layers.Dropout(dropout)
# Shape of queries: (batch_size, no. of queries, d)
# Shape of keys: (batch_size, no. of key-value pairs, d)
# Shape of values: (batch_size, no. of key-value pairs, value dimension)
# Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
def call(self, queries, keys, values, valid_lens=None, **kwargs):
d = queries.shape[-1]
scores = tf.matmul(queries, keys, transpose_b=True)/tf.math.sqrt(
tf.cast(d, dtype=tf.float32))
self.attention_weights = masked_softmax(scores, valid_lens)
return tf.matmul(self.dropout(self.attention_weights, **kwargs), values)
.. raw:: html
.. raw:: html
``DotProductAttention``
クラスがどのように動作するかを示すために、先ほどの加法アテンションの玩具例と同じキー、値、有効長を使う。この例では、ミニバッチサイズを
:math:`2`\ 、キーと値の総数を :math:`10`\ 、値の次元を :math:`4`
と仮定する。さらに、各観測の有効長はそれぞれ :math:`2` と :math:`6`
とする。すると、出力は :math:`2 \times 1 \times 4`
のテンソル、つまりミニバッチの各例につき1行になるはずである。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
queries = d2l.normal(0, 1, (2, 1, 2))
keys = d2l.normal(0, 1, (2, 10, 2))
values = d2l.normal(0, 1, (2, 10, 4))
valid_lens = d2l.tensor([2, 6])
attention = DotProductAttention(dropout=0.5)
attention.eval()
d2l.check_shape(attention(queries, keys, values, valid_lens), (2, 1, 4))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
queries = d2l.normal(0, 1, (2, 1, 2))
keys = d2l.normal(0, 1, (2, 10, 2))
values = d2l.normal(0, 1, (2, 10, 4))
valid_lens = d2l.tensor([2, 6])
attention = DotProductAttention(dropout=0.5)
attention.initialize()
d2l.check_shape(attention(queries, keys, values, valid_lens), (2, 1, 4))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
queries = jax.random.normal(d2l.get_key(), (2, 1, 2))
keys = jax.random.normal(d2l.get_key(), (2, 10, 2))
values = jax.random.normal(d2l.get_key(), (2, 10, 4))
valid_lens = d2l.tensor([2, 6])
attention = DotProductAttention(dropout=0.5)
(output, attention_weights), params = attention.init_with_output(
d2l.get_key(), queries, keys, values, valid_lens)
print(output)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
queries = tf.random.normal(shape=(2, 1, 2))
keys = tf.random.normal(shape=(2, 10, 2))
values = tf.random.normal(shape=(2, 10, 4))
valid_lens = tf.constant([2, 6])
attention = DotProductAttention(dropout=0.5)
d2l.check_shape(attention(queries, keys, values, valid_lens, training=False),
(2, 1, 4))
.. raw:: html
.. raw:: html
アテンション重みが実際に、それぞれ第2列と第6列を超える部分で消えているか確認してみよう(有効長を
:math:`2` と :math:`6` に設定しているためである)。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
d2l.show_heatmaps(d2l.reshape(attention.attention_weights, (1, 1, 2, 10)),
xlabel='Keys', ylabel='Queries')
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
d2l.show_heatmaps(d2l.reshape(attention.attention_weights, (1, 1, 2, 10)),
xlabel='Keys', ylabel='Queries')
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
d2l.show_heatmaps(d2l.reshape(attention_weights, (1, 1, 2, 10)),
xlabel='Keys', ylabel='Queries')
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
d2l.show_heatmaps(d2l.reshape(attention.attention_weights, (1, 1, 2, 10)),
xlabel='Keys', ylabel='Queries')
.. raw:: html
.. raw:: html
.. _subsec_additive-attention:
加法アテンション
----------------
クエリ :math:`\mathbf{q}` とキー :math:`\mathbf{k}`
が異なる次元のベクトルであるとき、\ :math:`\mathbf{q}^\top \mathbf{M} \mathbf{k}`
によって次元の不一致を行列で調整する方法もあれば、スコアリング関数として加法アテンションを使う方法もある。もう1つの利点は、その名の通りアテンションが加法的であることである。これにより、わずかな計算量の節約が可能になる。クエリ
:math:`\mathbf{q} \in \mathbb{R}^q` とキー
:math:`\mathbf{k} \in \mathbb{R}^k` に対して、\ *加法アテンション*
のスコアリング関数 :cite:`Bahdanau.Cho.Bengio.2014`
は次のように与えられる。
.. math:: a(\mathbf q, \mathbf k) = \mathbf w_v^\top \textrm{tanh}(\mathbf W_q\mathbf q + \mathbf W_k \mathbf k) \in \mathbb{R},
:label: eq_additive-attn
ここで、\ :math:`\mathbf W_q\in\mathbb R^{h\times q}`\ 、\ :math:`\mathbf W_k\in\mathbb R^{h\times k}`\ 、\ :math:`\mathbf w_v\in\mathbb R^{h}`
は学習可能なパラメータである。この項をソフトマックスに入力して、非負性と正規化の両方を保証する。:eq:`eq_additive-attn`
の同値な解釈として、クエリとキーを連結し、1つの隠れ層を持つMLPに入力しているとみなすこともできる。活性化関数として
:math:`\tanh`
を使い、バイアス項を無効にして、加法アテンションを次のように実装する。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class AdditiveAttention(nn.Module): #@save
"""Additive attention."""
def __init__(self, num_hiddens, dropout, **kwargs):
super(AdditiveAttention, self).__init__(**kwargs)
self.W_k = nn.LazyLinear(num_hiddens, bias=False)
self.W_q = nn.LazyLinear(num_hiddens, bias=False)
self.w_v = nn.LazyLinear(1, bias=False)
self.dropout = nn.Dropout(dropout)
def forward(self, queries, keys, values, valid_lens):
queries, keys = self.W_q(queries), self.W_k(keys)
# After dimension expansion, shape of queries: (batch_size, no. of
# queries, 1, num_hiddens) and shape of keys: (batch_size, 1, no. of
# key-value pairs, num_hiddens). Sum them up with broadcasting
features = queries.unsqueeze(2) + keys.unsqueeze(1)
features = torch.tanh(features)
# There is only one output of self.w_v, so we remove the last
# one-dimensional entry from the shape. Shape of scores: (batch_size,
# no. of queries, no. of key-value pairs)
scores = self.w_v(features).squeeze(-1)
self.attention_weights = masked_softmax(scores, valid_lens)
# Shape of values: (batch_size, no. of key-value pairs, value
# dimension)
return torch.bmm(self.dropout(self.attention_weights), values)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class AdditiveAttention(nn.Block): #@save
"""Additive attention."""
def __init__(self, num_hiddens, dropout, **kwargs):
super(AdditiveAttention, self).__init__(**kwargs)
# Use flatten=False to only transform the last axis so that the
# shapes for the other axes are kept the same
self.W_k = nn.Dense(num_hiddens, use_bias=False, flatten=False)
self.W_q = nn.Dense(num_hiddens, use_bias=False, flatten=False)
self.w_v = nn.Dense(1, use_bias=False, flatten=False)
self.dropout = nn.Dropout(dropout)
def forward(self, queries, keys, values, valid_lens):
queries, keys = self.W_q(queries), self.W_k(keys)
# After dimension expansion, shape of queries: (batch_size, no. of
# queries, 1, num_hiddens) and shape of keys: (batch_size, 1,
# no. of key-value pairs, num_hiddens). Sum them up with
# broadcasting
features = np.expand_dims(queries, axis=2) + np.expand_dims(
keys, axis=1)
features = np.tanh(features)
# There is only one output of self.w_v, so we remove the last
# one-dimensional entry from the shape. Shape of scores:
# (batch_size, no. of queries, no. of key-value pairs)
scores = np.squeeze(self.w_v(features), axis=-1)
self.attention_weights = masked_softmax(scores, valid_lens)
# Shape of values: (batch_size, no. of key-value pairs, value
# dimension)
return npx.batch_dot(self.dropout(self.attention_weights), values)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class AdditiveAttention(nn.Module): #@save
num_hiddens: int
dropout: float
def setup(self):
self.W_k = nn.Dense(self.num_hiddens, use_bias=False)
self.W_q = nn.Dense(self.num_hiddens, use_bias=False)
self.w_v = nn.Dense(1, use_bias=False)
@nn.compact
def __call__(self, queries, keys, values, valid_lens, training=False):
queries, keys = self.W_q(queries), self.W_k(keys)
# After dimension expansion, shape of queries: (batch_size, no. of
# queries, 1, num_hiddens) and shape of keys: (batch_size, 1, no. of
# key-value pairs, num_hiddens). Sum them up with broadcasting
features = jnp.expand_dims(queries, axis=2) + jnp.expand_dims(keys, axis=1)
features = nn.tanh(features)
# There is only one output of self.w_v, so we remove the last
# one-dimensional entry from the shape. Shape of scores: (batch_size,
# no. of queries, no. of key-value pairs)
scores = self.w_v(features).squeeze(-1)
attention_weights = masked_softmax(scores, valid_lens)
dropout_layer = nn.Dropout(self.dropout, deterministic=not training)
# Shape of values: (batch_size, no. of key-value pairs, value
# dimension)
return dropout_layer(attention_weights)@values, attention_weights
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class AdditiveAttention(tf.keras.layers.Layer): #@save
"""Additive attention."""
def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
super().__init__(**kwargs)
self.W_k = tf.keras.layers.Dense(num_hiddens, use_bias=False)
self.W_q = tf.keras.layers.Dense(num_hiddens, use_bias=False)
self.w_v = tf.keras.layers.Dense(1, use_bias=False)
self.dropout = tf.keras.layers.Dropout(dropout)
def call(self, queries, keys, values, valid_lens, **kwargs):
queries, keys = self.W_q(queries), self.W_k(keys)
# After dimension expansion, shape of queries: (batch_size, no. of
# queries, 1, num_hiddens) and shape of keys: (batch_size, 1, no. of
# key-value pairs, num_hiddens). Sum them up with broadcasting
features = tf.expand_dims(queries, axis=2) + tf.expand_dims(
keys, axis=1)
features = tf.nn.tanh(features)
# There is only one output of self.w_v, so we remove the last
# one-dimensional entry from the shape. Shape of scores: (batch_size,
# no. of queries, no. of key-value pairs)
scores = tf.squeeze(self.w_v(features), axis=-1)
self.attention_weights = masked_softmax(scores, valid_lens)
# Shape of values: (batch_size, no. of key-value pairs, value
# dimension)
return tf.matmul(self.dropout(
self.attention_weights, **kwargs), values)
.. raw:: html
.. raw:: html
``AdditiveAttention``
がどのように動作するかを見てみよう。玩具例では、クエリ、キー、値のサイズをそれぞれ
:math:`(2, 1, 20)`\ 、\ :math:`(2, 10, 2)`\ 、\ :math:`(2, 10, 4)`
とする。これは ``DotProductAttention``
のときの選択と同じであるが、今回はクエリが20次元である点が異なる。同様に、ミニバッチ内の系列の有効長として
:math:`(2, 6)` を選ぶ。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
queries = d2l.normal(0, 1, (2, 1, 20))
attention = AdditiveAttention(num_hiddens=8, dropout=0.1)
attention.eval()
d2l.check_shape(attention(queries, keys, values, valid_lens), (2, 1, 4))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
queries = d2l.normal(0, 1, (2, 1, 20))
attention = AdditiveAttention(num_hiddens=8, dropout=0.1)
attention.initialize()
d2l.check_shape(attention(queries, keys, values, valid_lens), (2, 1, 4))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
queries = jax.random.normal(d2l.get_key(), (2, 1, 20))
attention = AdditiveAttention(num_hiddens=8, dropout=0.1)
(output, attention_weights), params = attention.init_with_output(
d2l.get_key(), queries, keys, values, valid_lens)
print(output)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
[[[-0.80551296 -0.46030337 -0.02557675 1.0253853 ]]
[[ 0.25692716 0.2782622 -0.0186431 0.5796091 ]]]
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
queries = tf.random.normal(shape=(2, 1, 20))
attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8,
dropout=0.1)
d2l.check_shape(attention(queries, keys, values, valid_lens, training=False),
(2, 1, 4))
.. raw:: html
.. raw:: html
アテンション関数を確認すると、\ ``DotProductAttention``
の場合と質的にかなり似た振る舞いが見られる。つまり、選択した有効長
:math:`(2, 6)` の範囲内の項だけが非ゼロである。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
d2l.show_heatmaps(d2l.reshape(attention.attention_weights, (1, 1, 2, 10)),
xlabel='Keys', ylabel='Queries')
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
d2l.show_heatmaps(d2l.reshape(attention.attention_weights, (1, 1, 2, 10)),
xlabel='Keys', ylabel='Queries')
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
d2l.show_heatmaps(d2l.reshape(attention_weights, (1, 1, 2, 10)),
xlabel='Keys', ylabel='Queries')
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
d2l.show_heatmaps(d2l.reshape(attention.attention_weights, (1, 1, 2, 10)),
xlabel='Keys', ylabel='Queries')
.. raw:: html
.. raw:: html
まとめ
------
この節では、2つの主要なアテンションのスコアリング関数、すなわち内積アテンションと加法アテンションを導入した。これらは、長さが可変な系列全体を集約するための有効な手段である。特に、内積アテンションは現代のTransformerアーキテクチャの中核をなしている。クエリとキーが異なる長さのベクトルである場合には、代わりに加法アテンションのスコアリング関数を使うことができる。これらの層を最適化することは、近年の進歩の重要な分野の1つである。たとえば、\ `NVIDIA
の Transformer
Library `__
や Megatron :cite:`shoeybi2019megatron`
は、効率的なアテンション機構の変種に大きく依存している。後の節でTransformerを学ぶ際に、これについてさらに詳しく見ていく。
演習
----
1. ``DotProductAttention``
のコードを修正して、距離ベースのアテンションを実装せよ。効率的な実装には、キーの二乗ノルム
:math:`\|\mathbf{k}_i\|^2` だけが必要であることに注意せよ。
2. 行列を用いて次元を調整することで、異なる次元のクエリとキーを扱えるように内積アテンションを修正せよ。
3. 計算コストは、キー、クエリ、値の次元およびその個数に対してどのようにスケールするか。メモリ帯域幅の要件についてはどうか。