11.3. アテンションのスコアリング関数

11.2 章 では、クエリとキーの相互作用をモデル化するために、ガウスカーネルを含むいくつかの距離ベースのカーネルを用いた。ところが、距離関数は内積よりも計算コストがやや高いことがわかっている。そのため、非負のアテンション重みを保証するソフトマックス演算と組み合わせる場合、計算がより簡単な アテンションのスコアリング関数 \(a\) に多くの工夫が注がれてきた。これは (11.1.3)図 11.3.1 に現れる。

../_images/attention-output.svg

図 11.3.1 アテンションプーリングの出力を値の重み付き平均として計算する。重みはアテンションのスコアリング関数 \(\mathit{a}\) とソフトマックス演算で求める。

from d2l import torch as d2l
import math
import torch
from torch import nn
import math
from d2l import mxnet as d2l
from mxnet import np, npx
from mxnet.gluon import nn
npx.set_np()
from d2l import jax as d2l
from flax import linen as nn
from jax import numpy as jnp
import jax
import math
from d2l import tensorflow as d2l
import tensorflow as tf

11.3.1. 内積アテンション

まず、ガウスカーネルから得られるアテンション関数(指数関数を除く)を少し見直してみよう。

(11.3.1)\[a(\mathbf{q}, \mathbf{k}_i) = -\frac{1}{2} \|\mathbf{q} - \mathbf{k}_i\|^2 = \mathbf{q}^\top \mathbf{k}_i -\frac{1}{2} \|\mathbf{k}_i\|^2 -\frac{1}{2} \|\mathbf{q}\|^2.\]

まず、最後の項は \(\mathbf{q}\) のみに依存することに注意されたい。したがって、すべての \((\mathbf{q}, \mathbf{k}_i)\) の組に対して同じ値である。(11.1.3) で行うように、アテンション重みを \(1\) に正規化すると、この項は完全に消える。次に、バッチ正規化と層正規化(後で説明する)のどちらも、十分に有界で、しばしば一定のノルム \(\|\mathbf{k}_i\|\) を持つ活性化をもたらすことにも注意されたい。たとえば、キー \(\mathbf{k}_i\) が層正規化によって生成されている場合がそうである。したがって、結果を大きく変えることなく、この項を \(a\) の定義から取り除くことができる。

最後に、指数関数の引数のオーダーを適切に制御する必要がある。クエリ \(\mathbf{q} \in \mathbb{R}^d\) とキー \(\mathbf{k}_i \in \mathbb{R}^d\) のすべての要素が、平均0・分散1の独立同分布な乱数であると仮定しよう。両ベクトルの内積は平均0、分散 \(d\) になる。ベクトル長に依らず内積の分散を \(1\) に保つために、スケールド内積アテンション のスコアリング関数を用いる。つまり、内積を \(1/\sqrt{d}\) で再スケールする。こうして、Transformer などで使われる最初の一般的なアテンション関数に到達する (Vaswani et al., 2017):

(11.3.2)\[a(\mathbf{q}, \mathbf{k}_i) = \mathbf{q}^\top \mathbf{k}_i / \sqrt{d}.\]

アテンション重み \(\alpha\) は依然として正規化が必要であることに注意されたい。これを (11.1.3) によりさらに簡単にするため、ソフトマックス演算を用いる。

(11.3.3)\[\alpha(\mathbf{q}, \mathbf{k}_i) = \mathrm{softmax}(a(\mathbf{q}, \mathbf{k}_i)) = \frac{\exp(\mathbf{q}^\top \mathbf{k}_i / \sqrt{d})}{\sum_{j=1} \exp(\mathbf{q}^\top \mathbf{k}_j / \sqrt{d})}.\]

実際、広く使われているアテンション機構はすべてソフトマックスを用いているため、この章の残りではそれに限定する。

11.3.2. 便利な関数

アテンション機構を効率よく実装するために、いくつかの関数が必要である。これには、長さが可変な文字列を扱うためのツール(自然言語処理で一般的)と、ミニバッチ上で効率よく評価するためのツール(バッチ行列積)が含まれる。

11.3.2.1. マスク付きソフトマックス演算

アテンション機構の最も一般的な応用の1つは系列モデルである。したがって、長さの異なる系列を扱える必要がある。場合によっては、そのような系列が同じミニバッチに入ることがあり、短い系列にはダミートークンによるパディングが必要になる(例は 10.5 章 を参照)。これらの特別なトークンは意味を持たない。たとえば、次の3つの文があるとする。

Dive  into  Deep    Learning
Learn to    code    <blank>
Hello world <blank> <blank>

アテンションモデルに空白を入れたくないので、単に \(\sum_{i=1}^n \alpha(\mathbf{q}, \mathbf{k}_i) \mathbf{v}_i\) を、実際の文の長さ \(l \leq n\) に応じて \(\sum_{i=1}^l \alpha(\mathbf{q}, \mathbf{k}_i) \mathbf{v}_i\) に制限すればよいのである。これは非常に一般的な問題なので、名前が付いている。それが マスク付きソフトマックス演算 である。

実装してみよう。実際の実装では、\(i > l\)\(\mathbf{v}_i\) の値をゼロにすることで、わずかにごまかしている。さらに、勾配や値への寄与を実質的に消すために、アテンション重みを \(-10^{6}\) のような非常に大きな負の値に設定する。これは、線形代数のカーネルや演算子がGPU向けに強く最適化されており、条件分岐(if then else)を含むコードにするよりも、多少計算を無駄にしても速いからである。

def masked_softmax(X, valid_lens):  #@save
    """Perform softmax operation by masking elements on the last axis."""
    # X: 3D tensor, valid_lens: 1D or 2D tensor
    def _sequence_mask(X, valid_len, value=0):
        maxlen = X.size(1)
        mask = torch.arange((maxlen), dtype=torch.float32,
                            device=X.device)[None, :] < valid_len[:, None]
        X[~mask] = value
        return X

    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # On the last axis, replace masked elements with a very large negative
        # value, whose exponentiation outputs 0
        X = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)
def masked_softmax(X, valid_lens):  #@save
    """Perform softmax operation by masking elements on the last axis."""
    # X: 3D tensor, valid_lens: 1D or 2D tensor
    if valid_lens is None:
        return npx.softmax(X)
    else:
        shape = X.shape
        if valid_lens.ndim == 1:
            valid_lens = valid_lens.repeat(shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # On the last axis, replace masked elements with a very large negative
        # value, whose exponentiation outputs 0
        X = npx.sequence_mask(X.reshape(-1, shape[-1]), valid_lens, True,
                              value=-1e6, axis=1)
        return npx.softmax(X).reshape(shape)
def masked_softmax(X, valid_lens):  #@save
    """Perform softmax operation by masking elements on the last axis."""
    # X: 3D tensor, valid_lens: 1D or 2D tensor
    def _sequence_mask(X, valid_len, value=0):
        maxlen = X.shape[1]
        mask = jnp.arange((maxlen),
                          dtype=jnp.float32)[None, :] < valid_len[:, None]
        return jnp.where(mask, X, value)

    if valid_lens is None:
        return nn.softmax(X, axis=-1)
    else:
        shape = X.shape
        if valid_lens.ndim == 1:
            valid_lens = jnp.repeat(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # On the last axis, replace masked elements with a very large negative
        # value, whose exponentiation outputs 0
        X = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
        return nn.softmax(X.reshape(shape), axis=-1)
def masked_softmax(X, valid_lens):  #@save
    """Perform softmax operation by masking elements on the last axis."""
    # X: 3D tensor, valid_lens: 1D or 2D tensor
    def _sequence_mask(X, valid_len, value=0):
        maxlen = X.shape[1]
        mask = tf.range(start=0, limit=maxlen, dtype=tf.float32)[
            None, :] < tf.cast(valid_len[:, None], dtype=tf.float32)

        if len(X.shape) == 3:
            return tf.where(tf.expand_dims(mask, axis=-1), X, value)
        else:
            return tf.where(mask, X, value)

    if valid_lens is None:
        return tf.nn.softmax(X, axis=-1)
    else:
        shape = X.shape
        if len(valid_lens.shape) == 1:
            valid_lens = tf.repeat(valid_lens, repeats=shape[1])

        else:
            valid_lens = tf.reshape(valid_lens, shape=-1)
        # On the last axis, replace masked elements with a very large negative
        # value, whose exponentiation outputs 0
        X = _sequence_mask(tf.reshape(X, shape=(-1, shape[-1])), valid_lens,
                           value=-1e6)
        return tf.nn.softmax(tf.reshape(X, shape=shape), axis=-1)

この関数がどのように動作するかを示すため、サイズが \(2 \times 4\) の2つの例からなるミニバッチを考え、それぞれの有効長が \(2\)\(3\) であるとする。マスク付きソフトマックス演算の結果、各ベクトルの組について有効長を超える値はすべてゼロとしてマスクされる。

masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))
tensor([[[0.4639, 0.5361, 0.0000, 0.0000],
         [0.3481, 0.6519, 0.0000, 0.0000]],

        [[0.4261, 0.1981, 0.3759, 0.0000],
         [0.4635, 0.2937, 0.2428, 0.0000]]])
masked_softmax(np.random.uniform(size=(2, 2, 4)), d2l.tensor([2, 3]))
masked_softmax(jax.random.uniform(d2l.get_key(), (2, 2, 4)), jnp.array([2, 3]))
Array([[[0.54144967, 0.4585504 , 0.        , 0.        ],
        [0.54905266, 0.45094734, 0.        , 0.        ]],

       [[0.23869859, 0.52262574, 0.2386757 , 0.        ],
        [0.4002735 , 0.3469047 , 0.25282183, 0.        ]]], dtype=float32)
masked_softmax(tf.random.uniform(shape=(2, 2, 4)), tf.constant([2, 3]))
<tf.Tensor: shape=(2, 2, 4), dtype=float32, numpy=
array([[[0.43959942, 0.56040066, 0.        , 0.        ],
        [0.56975996, 0.43024006, 0.        , 0.        ]],

       [[0.3823931 , 0.305755  , 0.3118519 , 0.        ],
        [0.37867242, 0.25606048, 0.36526713, 0.        ]]], dtype=float32)>

各例の2つのベクトルそれぞれに対して有効長をより細かく指定したい場合は、単に2次元の有効長テンソルを使う。すると次のようになる。

masked_softmax(torch.rand(2, 2, 4), d2l.tensor([[1, 3], [2, 4]]))
masked_softmax(np.random.uniform(size=(2, 2, 4)),
               d2l.tensor([[1, 3], [2, 4]]))
masked_softmax(jax.random.uniform(d2l.get_key(), (2, 2, 4)),
               jnp.array([[1, 3], [2, 4]]))
Array([[[1.        , 0.        , 0.        , 0.        ],
        [0.29248205, 0.37135315, 0.3361648 , 0.        ]],

       [[0.38551602, 0.61448395, 0.        , 0.        ],
        [0.17129263, 0.21547154, 0.25191936, 0.36131644]]], dtype=float32)
masked_softmax(tf.random.uniform((2, 2, 4)), tf.constant([[1, 3], [2, 4]]))
<tf.Tensor: shape=(2, 2, 4), dtype=float32, numpy=
array([[[1.        , 0.        , 0.        , 0.        ],
        [0.4731181 , 0.3053249 , 0.221557  , 0.        ]],

       [[0.3913298 , 0.6086702 , 0.        , 0.        ],
        [0.16784328, 0.24103056, 0.23714985, 0.35397628]]], dtype=float32)>

11.3.2.2. バッチ行列積

もう1つよく使われる演算は、行列のバッチ同士を掛け合わせることである。これは、クエリ、キー、値のミニバッチを扱うときに便利である。より具体的には、次を仮定しよう。

(11.3.4)\[\begin{split}\mathbf{Q} = [\mathbf{Q}_1, \mathbf{Q}_2, \ldots, \mathbf{Q}_n] \in \mathbb{R}^{n \times a \times b}, \\ \mathbf{K} = [\mathbf{K}_1, \mathbf{K}_2, \ldots, \mathbf{K}_n] \in \mathbb{R}^{n \times b \times c}.\end{split}\]

このとき、バッチ行列積(BMM)は要素ごとの積を計算する。

(11.3.5)\[\textrm{BMM}(\mathbf{Q}, \mathbf{K}) = [\mathbf{Q}_1 \mathbf{K}_1, \mathbf{Q}_2 \mathbf{K}_2, \ldots, \mathbf{Q}_n \mathbf{K}_n] \in \mathbb{R}^{n \times a \times c}.\]

深層学習フレームワークでこれを見てみよう。

Q = d2l.ones((2, 3, 4))
K = d2l.ones((2, 4, 6))
d2l.check_shape(torch.bmm(Q, K), (2, 3, 6))
Q = d2l.ones((2, 3, 4))
K = d2l.ones((2, 4, 6))
d2l.check_shape(npx.batch_dot(Q, K), (2, 3, 6))
Q = d2l.ones((2, 3, 4))
K = d2l.ones((2, 4, 6))
d2l.check_shape(jax.lax.batch_matmul(Q, K), (2, 3, 6))
Q = d2l.ones((2, 3, 4))
K = d2l.ones((2, 4, 6))
d2l.check_shape(tf.matmul(Q, K).numpy(), (2, 3, 6))

11.3.3. スケールド内積アテンション

(11.3.2) で導入した内積アテンションに戻ろう。 一般に、クエリとキーの両方が同じベクトル長、たとえば \(d\) を持つことが必要であるが、これは \(\mathbf{q}^\top \mathbf{k}\)\(\mathbf{q}^\top \mathbf{M} \mathbf{k}\) に置き換え、\(\mathbf{M}\) を両空間を変換するために適切に選んだ行列とすることで簡単に対処できる。ここでは、次元が一致していると仮定する。

実際には、効率のためにミニバッチを考えることが多く、たとえば \(n\) 個のクエリと \(m\) 個のキー・値ペアに対してアテンションを計算する。このとき、クエリとキーの長さは \(d\)、値の長さは \(v\) である。したがって、クエリ \(\mathbf Q\in\mathbb R^{n\times d}\)、キー \(\mathbf K\in\mathbb R^{m\times d}\)、値 \(\mathbf V\in\mathbb R^{m\times v}\) に対するスケールド内積アテンションは次のように書ける。

(11.3.6)\[\mathrm{softmax}\left(\frac{\mathbf Q \mathbf K^\top }{\sqrt{d}}\right) \mathbf V \in \mathbb{R}^{n\times v}.\]

これをミニバッチに適用する際には、(11.3.5) で導入したバッチ行列積が必要になることに注意されたい。以下のスケールド内積アテンションの実装では、モデル正則化のためにドロップアウトを使う。

class DotProductAttention(nn.Module):  #@save
    """Scaled dot product attention."""
    def __init__(self, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    # Shape of queries: (batch_size, no. of queries, d)
    # Shape of keys: (batch_size, no. of key-value pairs, d)
    # Shape of values: (batch_size, no. of key-value pairs, value dimension)
    # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        # Swap the last two dimensions of keys with keys.transpose(1, 2)
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)
class DotProductAttention(nn.Block):  #@save
    """Scaled dot product attention."""
    def __init__(self, dropout):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    # Shape of queries: (batch_size, no. of queries, d)
    # Shape of keys: (batch_size, no. of key-value pairs, d)
    # Shape of values: (batch_size, no. of key-value pairs, value dimension)
    # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        # Set transpose_b=True to swap the last two dimensions of keys
        scores = npx.batch_dot(queries, keys, transpose_b=True) / math.sqrt(d)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return npx.batch_dot(self.dropout(self.attention_weights), values)
class DotProductAttention(nn.Module):  #@save
    """Scaled dot product attention."""
    dropout: float

    # Shape of queries: (batch_size, no. of queries, d)
    # Shape of keys: (batch_size, no. of key-value pairs, d)
    # Shape of values: (batch_size, no. of key-value pairs, value dimension)
    # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
    @nn.compact
    def __call__(self, queries, keys, values, valid_lens=None,
                 training=False):
        d = queries.shape[-1]
        # Swap the last two dimensions of keys with keys.swapaxes(1, 2)
        scores = queries@(keys.swapaxes(1, 2)) / math.sqrt(d)
        attention_weights = masked_softmax(scores, valid_lens)
        dropout_layer = nn.Dropout(self.dropout, deterministic=not training)
        return dropout_layer(attention_weights)@values, attention_weights
class DotProductAttention(tf.keras.layers.Layer):  #@save
    """Scaled dot product attention."""
    def __init__(self, dropout):
        super().__init__()
        self.dropout = tf.keras.layers.Dropout(dropout)

    # Shape of queries: (batch_size, no. of queries, d)
    # Shape of keys: (batch_size, no. of key-value pairs, d)
    # Shape of values: (batch_size, no. of key-value pairs, value dimension)
    # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
    def call(self, queries, keys, values, valid_lens=None, **kwargs):
        d = queries.shape[-1]
        scores = tf.matmul(queries, keys, transpose_b=True)/tf.math.sqrt(
            tf.cast(d, dtype=tf.float32))
        self.attention_weights = masked_softmax(scores, valid_lens)
        return tf.matmul(self.dropout(self.attention_weights, **kwargs), values)

DotProductAttention クラスがどのように動作するかを示すために、先ほどの加法アテンションの玩具例と同じキー、値、有効長を使う。この例では、ミニバッチサイズを \(2\)、キーと値の総数を \(10\)、値の次元を \(4\) と仮定する。さらに、各観測の有効長はそれぞれ \(2\)\(6\) とする。すると、出力は \(2 \times 1 \times 4\) のテンソル、つまりミニバッチの各例につき1行になるはずである。

queries = d2l.normal(0, 1, (2, 1, 2))
keys = d2l.normal(0, 1, (2, 10, 2))
values = d2l.normal(0, 1, (2, 10, 4))
valid_lens = d2l.tensor([2, 6])

attention = DotProductAttention(dropout=0.5)
attention.eval()
d2l.check_shape(attention(queries, keys, values, valid_lens), (2, 1, 4))
queries = d2l.normal(0, 1, (2, 1, 2))
keys = d2l.normal(0, 1, (2, 10, 2))
values = d2l.normal(0, 1, (2, 10, 4))
valid_lens = d2l.tensor([2, 6])

attention = DotProductAttention(dropout=0.5)
attention.initialize()
d2l.check_shape(attention(queries, keys, values, valid_lens), (2, 1, 4))
queries = jax.random.normal(d2l.get_key(), (2, 1, 2))
keys = jax.random.normal(d2l.get_key(), (2, 10, 2))
values = jax.random.normal(d2l.get_key(), (2, 10, 4))
valid_lens = d2l.tensor([2, 6])

attention = DotProductAttention(dropout=0.5)
(output, attention_weights), params = attention.init_with_output(
    d2l.get_key(), queries, keys, values, valid_lens)
print(output)
queries = tf.random.normal(shape=(2, 1, 2))
keys = tf.random.normal(shape=(2, 10, 2))
values = tf.random.normal(shape=(2, 10, 4))
valid_lens = tf.constant([2, 6])

attention = DotProductAttention(dropout=0.5)
d2l.check_shape(attention(queries, keys, values, valid_lens, training=False),
                (2, 1, 4))

アテンション重みが実際に、それぞれ第2列と第6列を超える部分で消えているか確認してみよう(有効長を \(2\)\(6\) に設定しているためである)。

d2l.show_heatmaps(d2l.reshape(attention.attention_weights, (1, 1, 2, 10)),
                  xlabel='Keys', ylabel='Queries')
d2l.show_heatmaps(d2l.reshape(attention.attention_weights, (1, 1, 2, 10)),
                  xlabel='Keys', ylabel='Queries')
d2l.show_heatmaps(d2l.reshape(attention_weights, (1, 1, 2, 10)),
                  xlabel='Keys', ylabel='Queries')
d2l.show_heatmaps(d2l.reshape(attention.attention_weights, (1, 1, 2, 10)),
                  xlabel='Keys', ylabel='Queries')

11.3.4. 加法アテンション

クエリ \(\mathbf{q}\) とキー \(\mathbf{k}\) が異なる次元のベクトルであるとき、\(\mathbf{q}^\top \mathbf{M} \mathbf{k}\) によって次元の不一致を行列で調整する方法もあれば、スコアリング関数として加法アテンションを使う方法もある。もう1つの利点は、その名の通りアテンションが加法的であることである。これにより、わずかな計算量の節約が可能になる。クエリ \(\mathbf{q} \in \mathbb{R}^q\) とキー \(\mathbf{k} \in \mathbb{R}^k\) に対して、加法アテンション のスコアリング関数 (Bahdanau et al., 2014) は次のように与えられる。

(11.3.7)\[a(\mathbf q, \mathbf k) = \mathbf w_v^\top \textrm{tanh}(\mathbf W_q\mathbf q + \mathbf W_k \mathbf k) \in \mathbb{R},\]

ここで、\(\mathbf W_q\in\mathbb R^{h\times q}\)\(\mathbf W_k\in\mathbb R^{h\times k}\)\(\mathbf w_v\in\mathbb R^{h}\) は学習可能なパラメータである。この項をソフトマックスに入力して、非負性と正規化の両方を保証する。(11.3.7) の同値な解釈として、クエリとキーを連結し、1つの隠れ層を持つMLPに入力しているとみなすこともできる。活性化関数として \(\tanh\) を使い、バイアス項を無効にして、加法アテンションを次のように実装する。

class AdditiveAttention(nn.Module):  #@save
    """Additive attention."""
    def __init__(self, num_hiddens, dropout, **kwargs):
        super(AdditiveAttention, self).__init__(**kwargs)
        self.W_k = nn.LazyLinear(num_hiddens, bias=False)
        self.W_q = nn.LazyLinear(num_hiddens, bias=False)
        self.w_v = nn.LazyLinear(1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        queries, keys = self.W_q(queries), self.W_k(keys)
        # After dimension expansion, shape of queries: (batch_size, no. of
        # queries, 1, num_hiddens) and shape of keys: (batch_size, 1, no. of
        # key-value pairs, num_hiddens). Sum them up with broadcasting
        features = queries.unsqueeze(2) + keys.unsqueeze(1)
        features = torch.tanh(features)
        # There is only one output of self.w_v, so we remove the last
        # one-dimensional entry from the shape. Shape of scores: (batch_size,
        # no. of queries, no. of key-value pairs)
        scores = self.w_v(features).squeeze(-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        # Shape of values: (batch_size, no. of key-value pairs, value
        # dimension)
        return torch.bmm(self.dropout(self.attention_weights), values)
class AdditiveAttention(nn.Block):  #@save
    """Additive attention."""
    def __init__(self, num_hiddens, dropout, **kwargs):
        super(AdditiveAttention, self).__init__(**kwargs)
        # Use flatten=False to only transform the last axis so that the
        # shapes for the other axes are kept the same
        self.W_k = nn.Dense(num_hiddens, use_bias=False, flatten=False)
        self.W_q = nn.Dense(num_hiddens, use_bias=False, flatten=False)
        self.w_v = nn.Dense(1, use_bias=False, flatten=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens):
        queries, keys = self.W_q(queries), self.W_k(keys)
        # After dimension expansion, shape of queries: (batch_size, no. of
        # queries, 1, num_hiddens) and shape of keys: (batch_size, 1,
        # no. of key-value pairs, num_hiddens). Sum them up with
        # broadcasting
        features = np.expand_dims(queries, axis=2) + np.expand_dims(
            keys, axis=1)
        features = np.tanh(features)
        # There is only one output of self.w_v, so we remove the last
        # one-dimensional entry from the shape. Shape of scores:
        # (batch_size, no. of queries, no. of key-value pairs)
        scores = np.squeeze(self.w_v(features), axis=-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        # Shape of values: (batch_size, no. of key-value pairs, value
        # dimension)
        return npx.batch_dot(self.dropout(self.attention_weights), values)
class AdditiveAttention(nn.Module):  #@save
    num_hiddens: int
    dropout: float

    def setup(self):
        self.W_k = nn.Dense(self.num_hiddens, use_bias=False)
        self.W_q = nn.Dense(self.num_hiddens, use_bias=False)
        self.w_v = nn.Dense(1, use_bias=False)

    @nn.compact
    def __call__(self, queries, keys, values, valid_lens, training=False):
        queries, keys = self.W_q(queries), self.W_k(keys)
        # After dimension expansion, shape of queries: (batch_size, no. of
        # queries, 1, num_hiddens) and shape of keys: (batch_size, 1, no. of
        # key-value pairs, num_hiddens). Sum them up with broadcasting
        features = jnp.expand_dims(queries, axis=2) + jnp.expand_dims(keys, axis=1)
        features = nn.tanh(features)
        # There is only one output of self.w_v, so we remove the last
        # one-dimensional entry from the shape. Shape of scores: (batch_size,
        # no. of queries, no. of key-value pairs)
        scores = self.w_v(features).squeeze(-1)
        attention_weights = masked_softmax(scores, valid_lens)
        dropout_layer = nn.Dropout(self.dropout, deterministic=not training)
        # Shape of values: (batch_size, no. of key-value pairs, value
        # dimension)
        return dropout_layer(attention_weights)@values, attention_weights
class AdditiveAttention(tf.keras.layers.Layer):  #@save
    """Additive attention."""
    def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
        super().__init__(**kwargs)
        self.W_k = tf.keras.layers.Dense(num_hiddens, use_bias=False)
        self.W_q = tf.keras.layers.Dense(num_hiddens, use_bias=False)
        self.w_v = tf.keras.layers.Dense(1, use_bias=False)
        self.dropout = tf.keras.layers.Dropout(dropout)

    def call(self, queries, keys, values, valid_lens, **kwargs):
        queries, keys = self.W_q(queries), self.W_k(keys)
        # After dimension expansion, shape of queries: (batch_size, no. of
        # queries, 1, num_hiddens) and shape of keys: (batch_size, 1, no. of
        # key-value pairs, num_hiddens). Sum them up with broadcasting
        features = tf.expand_dims(queries, axis=2) + tf.expand_dims(
            keys, axis=1)
        features = tf.nn.tanh(features)
        # There is only one output of self.w_v, so we remove the last
        # one-dimensional entry from the shape. Shape of scores: (batch_size,
        # no. of queries, no. of key-value pairs)
        scores = tf.squeeze(self.w_v(features), axis=-1)
        self.attention_weights = masked_softmax(scores, valid_lens)
        # Shape of values: (batch_size, no. of key-value pairs, value
        # dimension)
        return tf.matmul(self.dropout(
            self.attention_weights, **kwargs), values)

AdditiveAttention がどのように動作するかを見てみよう。玩具例では、クエリ、キー、値のサイズをそれぞれ \((2, 1, 20)\)\((2, 10, 2)\)\((2, 10, 4)\) とする。これは DotProductAttention のときの選択と同じであるが、今回はクエリが20次元である点が異なる。同様に、ミニバッチ内の系列の有効長として \((2, 6)\) を選ぶ。

queries = d2l.normal(0, 1, (2, 1, 20))

attention = AdditiveAttention(num_hiddens=8, dropout=0.1)
attention.eval()
d2l.check_shape(attention(queries, keys, values, valid_lens), (2, 1, 4))
queries = d2l.normal(0, 1, (2, 1, 20))

attention = AdditiveAttention(num_hiddens=8, dropout=0.1)
attention.initialize()
d2l.check_shape(attention(queries, keys, values, valid_lens), (2, 1, 4))
queries = jax.random.normal(d2l.get_key(), (2, 1, 20))
attention = AdditiveAttention(num_hiddens=8, dropout=0.1)
(output, attention_weights), params = attention.init_with_output(
    d2l.get_key(), queries, keys, values, valid_lens)
print(output)
[[[-0.80551296 -0.46030337 -0.02557675  1.0253853 ]]

 [[ 0.25692716  0.2782622  -0.0186431   0.5796091 ]]]
queries = tf.random.normal(shape=(2, 1, 20))

attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8,
                              dropout=0.1)
d2l.check_shape(attention(queries, keys, values, valid_lens, training=False),
                (2, 1, 4))

アテンション関数を確認すると、DotProductAttention の場合と質的にかなり似た振る舞いが見られる。つまり、選択した有効長 \((2, 6)\) の範囲内の項だけが非ゼロである。

d2l.show_heatmaps(d2l.reshape(attention.attention_weights, (1, 1, 2, 10)),
                  xlabel='Keys', ylabel='Queries')
d2l.show_heatmaps(d2l.reshape(attention.attention_weights, (1, 1, 2, 10)),
                  xlabel='Keys', ylabel='Queries')
d2l.show_heatmaps(d2l.reshape(attention_weights, (1, 1, 2, 10)),
                  xlabel='Keys', ylabel='Queries')
d2l.show_heatmaps(d2l.reshape(attention.attention_weights, (1, 1, 2, 10)),
                  xlabel='Keys', ylabel='Queries')

11.3.5. まとめ

この節では、2つの主要なアテンションのスコアリング関数、すなわち内積アテンションと加法アテンションを導入した。これらは、長さが可変な系列全体を集約するための有効な手段である。特に、内積アテンションは現代のTransformerアーキテクチャの中核をなしている。クエリとキーが異なる長さのベクトルである場合には、代わりに加法アテンションのスコアリング関数を使うことができる。これらの層を最適化することは、近年の進歩の重要な分野の1つである。たとえば、NVIDIA の Transformer Library や Megatron (Shoeybi et al., 2019) は、効率的なアテンション機構の変種に大きく依存している。後の節でTransformerを学ぶ際に、これについてさらに詳しく見ていく。

11.3.6. 演習

  1. DotProductAttention のコードを修正して、距離ベースのアテンションを実装せよ。効率的な実装には、キーの二乗ノルム \(\|\mathbf{k}_i\|^2\) だけが必要であることに注意せよ。

  2. 行列を用いて次元を調整することで、異なる次元のクエリとキーを扱えるように内積アテンションを修正せよ。

  3. 計算コストは、キー、クエリ、値の次元およびその個数に対してどのようにスケールするか。メモリ帯域幅の要件についてはどうか。