.. _sec_self-attention-and-positional-encoding:
自己注意機構と位置エンコーディング
==================================
深層学習では、系列をエンコードするためにCNNやRNNをよく用いる。
ここで注意機構を念頭に置くと、 トークン列を注意機構に入力し、
各ステップで各トークンがそれぞれ独自のクエリ、キー、値を持つ
と考えることができる。
ここで、次の層におけるあるトークンの表現の値を計算するとき、
そのトークンは(クエリベクトルを介して)他の任意のトークンに (キー
ベクトルに基づいて一致を取りながら)注意を向けることができる。
クエリとキーの適合度スコアの全体を用いることで、
各トークンについて、他のトークンに対する適切な重み付き和を構成し、
表現を計算できる。 各トークンが互いのトークンに注意を向けるため
(デコーダのステップがエンコーダのステップに注意を向ける場合とは異なり)、
このようなアーキテクチャは通常 *自己注意* モデル
:cite:`Lin.Feng.Santos.ea.2017,Vaswani.Shazeer.Parmar.ea.2017`
と呼ばれ、 別の文脈では *intra-attention* モデル
:cite:`Cheng.Dong.Lapata.2016,Parikh.Tackstrom.Das.ea.2016,Paulus.Xiong.Socher.2017`
とも呼ばれる。
この節では、系列の順序に関する追加情報も含めた、自己注意を用いる系列エンコーディングについて議論する。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
from d2l import torch as d2l
import math
import torch
from torch import nn
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
from d2l import mxnet as d2l
import math
from mxnet import autograd, np, npx
from mxnet.gluon import nn
npx.set_np()
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
from d2l import jax as d2l
from flax import linen as nn
from jax import numpy as jnp
import jax
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
from d2l import tensorflow as d2l
import numpy as np
import tensorflow as tf
.. raw:: html
.. raw:: html
自己注意
--------
入力トークン列 :math:`\mathbf{x}_1, \ldots, \mathbf{x}_n` があり、任意の
:math:`\mathbf{x}_i \in \mathbb{R}^d`\ (\ :math:`1 \leq i \leq n`\ )とする。
その自己注意の出力は 同じ長さの系列
:math:`\mathbf{y}_1, \ldots, \mathbf{y}_n` であり、ここで
.. math:: \mathbf{y}_i = f(\mathbf{x}_i, (\mathbf{x}_1, \mathbf{x}_1), \ldots, (\mathbf{x}_n, \mathbf{x}_n)) \in \mathbb{R}^d
は :eq:`eq_attention_pooling` における注意プーリングの定義に従う。
マルチヘッド注意を用いると、 次のコード片は
形状が(バッチサイズ、時間ステップ数またはトークン列長、\ :math:`d`\ )のテンソルに対する自己注意を計算する。
出力テンソルは同じ形状を持つ。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_heads, 0.5)
batch_size, num_queries, valid_lens = 2, 4, d2l.tensor([3, 2])
X = d2l.ones((batch_size, num_queries, num_hiddens))
d2l.check_shape(attention(X, X, X, valid_lens),
(batch_size, num_queries, num_hiddens))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_heads, 0.5)
attention.initialize()
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_heads, 0.5)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
num_hiddens, num_heads, 0.5)
.. raw:: html
.. raw:: html
.. _subsec_cnn-rnn-self-attention:
CNN、RNN、自己注意の比較
------------------------
:math:`n` 個のトークンからなる系列を
同じ長さの別の系列へ写像するアーキテクチャを比較しよう。
ここで各入力トークンまたは出力トークンは :math:`d`
次元ベクトルで表される。 具体的には、 CNN、RNN、自己注意を考える。
それらの 計算量、 逐次演算、 最大経路長を比較する。
逐次演算は並列計算を妨げる一方で、
系列位置の任意の組み合わせ間の経路が短いほど、
系列内の長距離依存関係を学習しやすくなる
:cite:`Hochreiter.Bengio.Frasconi.ea.2001`\ 。
.. _fig_cnn-rnn-self-attention:
.. figure:: ../img/cnn-rnn-self-attention.svg
CNN(パディングトークンは省略)、RNN、自己注意アーキテクチャの比較。
テキスト系列を「一次元画像」とみなしてみよう。同様に、一次元CNNはテキスト中の
:math:`n`-gram のような局所特徴を処理できる。 長さ :math:`n`
の系列を考え、 カーネルサイズが :math:`k`\ 、
入力チャネル数と出力チャネル数がともに :math:`d` の畳み込み層を考える。
この畳み込み層の計算量は :math:`\mathcal{O}(knd^2)` である。
:numref:`fig_cnn-rnn-self-attention` が示すように、
CNNは階層的であるため、 逐次演算は :math:`\mathcal{O}(1)` で済み、
最大経路長は :math:`\mathcal{O}(n/k)` である。 たとえば、
:numref:`fig_cnn-rnn-self-attention` では、 :math:`\mathbf{x}_1` と
:math:`\mathbf{x}_5` は カーネルサイズ 3 の2層CNNの受容野内にある。
RNNの隠れ状態を更新するとき、 :math:`d \times d` の重み行列と :math:`d`
次元の隠れ状態の乗算の計算量は :math:`\mathcal{O}(d^2)` である。
系列長が :math:`n` なので、 再帰層の計算量は :math:`\mathcal{O}(nd^2)`
である。 :numref:`fig_cnn-rnn-self-attention` によれば、
並列化できない逐次演算が :math:`\mathcal{O}(n)` 回あり、 最大経路長も
:math:`\mathcal{O}(n)` である。
自己注意では、 クエリ、キー、値はすべて :math:`n \times d` 行列である。
:eq:`eq_softmax_QK_V` のスケールド・ドット積注意を考えると、
:math:`n \times d` 行列に :math:`d \times n` 行列を掛け、 その後、出力の
:math:`n \times n` 行列に :math:`n \times d` 行列を掛ける。 その結果、
自己注意の計算量は :math:`\mathcal{O}(n^2d)` になる。
:numref:`fig_cnn-rnn-self-attention` からわかるように、
各トークンは自己注意を通じて 他の任意のトークンに直接接続されている。
したがって、 計算は :math:`\mathcal{O}(1)` の逐次演算で並列に行え、
最大経路長も :math:`\mathcal{O}(1)` である。
要するに、 CNNと自己注意はいずれも並列計算の恩恵を受け、
自己注意は最大経路長が最も短い。
しかし、系列長に対して二次の計算量を持つため、
自己注意は非常に長い系列に対しては 極めて遅くなる。
.. _subsec_positional-encoding:
位置エンコーディング
--------------------
RNNが系列のトークンを 1つずつ再帰的に処理するのに対し、 自己注意は
逐次演算を捨てて 並列計算を優先する。 ただし、自己注意だけでは
系列の順序は保持されない。 入力系列がどの順序で到着したかを
モデルが知っていることが本当に重要な場合、 どうすればよいだろうか。
トークンの順序に関する情報を保持するための 主流の方法は、
各トークンに関連付けられた追加入力として
それをモデルに表現することである。 これらの入力は *位置エンコーディング*
と呼ばれ、 学習可能なものにも、あらかじめ固定されたものにもできる。
ここでは、正弦関数と余弦関数に基づく
固定位置エンコーディングの簡単な方式を説明する
:cite:`Vaswani.Shazeer.Parmar.ea.2017`\ 。
入力表現 :math:`\mathbf{X} \in \mathbb{R}^{n \times d}` が系列中の
:math:`n` 個のトークンの :math:`d` 次元埋め込みを含むとする。
位置エンコーディングは 同じ形状の位置埋め込み行列
:math:`\mathbf{P} \in \mathbb{R}^{n \times d}` を用いて
:math:`\mathbf{X} + \mathbf{P}` を出力する。 その :math:`i^\textrm{th}`
行 および :math:`(2j)^\textrm{th}` または :math:`(2j + 1)^\textrm{th}`
列の要素は
.. math:: \begin{aligned} p_{i, 2j} &= \sin\left(\frac{i}{10000^{2j/d}}\right),\\p_{i, 2j+1} &= \cos\left(\frac{i}{10000^{2j/d}}\right).\end{aligned}
:label: eq_positional-encoding-def
一見すると、 この三角関数を使った設計は奇妙に見える。
この設計の理由を説明する前に、 まず次の ``PositionalEncoding``
クラスで実装してみよう。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class PositionalEncoding(nn.Module): #@save
"""Positional encoding."""
def __init__(self, num_hiddens, dropout, max_len=1000):
super().__init__()
self.dropout = nn.Dropout(dropout)
# Create a long enough P
self.P = d2l.zeros((1, max_len, num_hiddens))
X = d2l.arange(max_len, dtype=torch.float32).reshape(
-1, 1) / torch.pow(10000, torch.arange(
0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
self.P[:, :, 0::2] = torch.sin(X)
self.P[:, :, 1::2] = torch.cos(X)
def forward(self, X):
X = X + self.P[:, :X.shape[1], :].to(X.device)
return self.dropout(X)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
batch_size, num_queries, valid_lens = 2, 4, d2l.tensor([3, 2])
X = d2l.ones((batch_size, num_queries, num_hiddens))
d2l.check_shape(attention(X, X, X, valid_lens),
(batch_size, num_queries, num_hiddens))
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
[07:15:03] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
batch_size, num_queries, valid_lens = 2, 4, d2l.tensor([3, 2])
X = d2l.ones((batch_size, num_queries, num_hiddens))
d2l.check_shape(attention.init_with_output(d2l.get_key(), X, X, X, valid_lens,
training=False)[0][0],
(batch_size, num_queries, num_hiddens))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
batch_size, num_queries, valid_lens = 2, 4, tf.constant([3, 2])
X = tf.ones((batch_size, num_queries, num_hiddens))
d2l.check_shape(attention(X, X, X, valid_lens, training=False),
(batch_size, num_queries, num_hiddens))
.. raw:: html
.. raw:: html
位置埋め込み行列 :math:`\mathbf{P}` では、 行は系列内の位置に対応し、
列は異なる位置エンコーディング次元を表す。 以下の例では、
位置埋め込み行列の :math:`6^{\textrm{th}}` 列と :math:`7^{\textrm{th}}`
列は、 :math:`8^{\textrm{th}}` 列と :math:`9^{\textrm{th}}` 列よりも
高い周波数を持つことがわかる。 :math:`6^{\textrm{th}}` 列と
:math:`7^{\textrm{th}}` 列 (\ :math:`8^{\textrm{th}}` 列と
:math:`9^{\textrm{th}}` 列も同様)の間のオフセットは、
正弦関数と余弦関数を交互に用いていることに由来する。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
X = pos_encoding(d2l.zeros((1, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]
d2l.plot(d2l.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
figsize=(6, 2.5), legend=["Col %d" % d for d in d2l.arange(6, 10)])
.. figure:: output_self-attention-and-positional-encoding_8df90e_48_0.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class PositionalEncoding(nn.Block): #@save
"""Positional encoding."""
def __init__(self, num_hiddens, dropout, max_len=1000):
super().__init__()
self.dropout = nn.Dropout(dropout)
# Create a long enough P
self.P = d2l.zeros((1, max_len, num_hiddens))
X = d2l.arange(max_len).reshape(-1, 1) / np.power(
10000, np.arange(0, num_hiddens, 2) / num_hiddens)
self.P[:, :, 0::2] = np.sin(X)
self.P[:, :, 1::2] = np.cos(X)
def forward(self, X):
X = X + self.P[:, :X.shape[1], :].as_in_ctx(X.ctx)
return self.dropout(X)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class PositionalEncoding(nn.Module): #@save
"""Positional encoding."""
num_hiddens: int
dropout: float
max_len: int = 1000
def setup(self):
# Create a long enough P
self.P = d2l.zeros((1, self.max_len, self.num_hiddens))
X = d2l.arange(self.max_len, dtype=jnp.float32).reshape(
-1, 1) / jnp.power(10000, jnp.arange(
0, self.num_hiddens, 2, dtype=jnp.float32) / self.num_hiddens)
self.P = self.P.at[:, :, 0::2].set(jnp.sin(X))
self.P = self.P.at[:, :, 1::2].set(jnp.cos(X))
@nn.compact
def __call__(self, X, training=False):
# Flax sow API is used to capture intermediate variables
self.sow('intermediates', 'P', self.P)
X = X + self.P[:, :X.shape[1], :]
return nn.Dropout(self.dropout)(X, deterministic=not training)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class PositionalEncoding(tf.keras.layers.Layer): #@save
"""Positional encoding."""
def __init__(self, num_hiddens, dropout, max_len=1000):
super().__init__()
self.dropout = tf.keras.layers.Dropout(dropout)
# Create a long enough P
self.P = np.zeros((1, max_len, num_hiddens))
X = np.arange(max_len, dtype=np.float32).reshape(
-1,1)/np.power(10000, np.arange(
0, num_hiddens, 2, dtype=np.float32) / num_hiddens)
self.P[:, :, 0::2] = np.sin(X)
self.P[:, :, 1::2] = np.cos(X)
def call(self, X, **kwargs):
X = X + self.P[:, :X.shape[1], :]
return self.dropout(X, **kwargs)
.. raw:: html
.. raw:: html
絶対的な位置情報
~~~~~~~~~~~~~~~~
エンコーディング次元に沿って周波数が単調に減少することが
絶対的な位置情報とどのように関係するかを見るために、
:math:`0, 1, \ldots, 7` の2進表現を出力してみよう。
ご覧のように、最下位ビット、下から2番目のビット、
下から3番目のビットは、それぞれ1つおき、2つおき、4つおきに変化する。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
for i in range(8):
print(f'{i} in binary is {i:>03b}')
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
0 in binary is 000
1 in binary is 001
2 in binary is 010
3 in binary is 011
4 in binary is 100
5 in binary is 101
6 in binary is 110
7 in binary is 111
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
pos_encoding.initialize()
X = pos_encoding(np.zeros((1, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]
d2l.plot(d2l.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
figsize=(6, 2.5), legend=["Col %d" % d for d in d2l.arange(6, 10)])
.. figure:: output_self-attention-and-positional-encoding_8df90e_66_0.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
params = pos_encoding.init(d2l.get_key(), d2l.zeros((1, num_steps, encoding_dim)))
X, inter_vars = pos_encoding.apply(params, d2l.zeros((1, num_steps, encoding_dim)),
mutable='intermediates')
P = inter_vars['intermediates']['P'][0] # retrieve intermediate value P
P = P[:, :X.shape[1], :]
d2l.plot(d2l.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
figsize=(6, 2.5), legend=["Col %d" % d for d in d2l.arange(6, 10)])
.. figure:: output_self-attention-and-positional-encoding_8df90e_69_0.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
encoding_dim, num_steps = 32, 60
pos_encoding = PositionalEncoding(encoding_dim, 0)
X = pos_encoding(tf.zeros((1, num_steps, encoding_dim)), training=False)
P = pos_encoding.P[:, :X.shape[1], :]
d2l.plot(np.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
figsize=(6, 2.5), legend=["Col %d" % d for d in np.arange(6, 10)])
.. figure:: output_self-attention-and-positional-encoding_8df90e_72_0.svg
.. raw:: html
.. raw:: html
2進表現では、上位ビットほど下位ビットよりも周波数が低くなる。
同様に、下のヒートマップが示すように、
位置エンコーディングは三角関数を用いて
エンコーディング次元に沿って周波数を減少させる。
出力は浮動小数点数なので、 このような連続表現は 2進表現よりも
空間効率が高い。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
P = P[0, :, :].unsqueeze(0).unsqueeze(0)
d2l.show_heatmaps(P, xlabel='Column (encoding dimension)',
ylabel='Row (position)', figsize=(3.5, 4), cmap='Blues')
.. figure:: output_self-attention-and-positional-encoding_8df90e_78_0.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
for i in range(8):
print(f'{i} in binary is {i:>03b}')
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
0 in binary is 000
1 in binary is 001
2 in binary is 010
3 in binary is 011
4 in binary is 100
5 in binary is 101
6 in binary is 110
7 in binary is 111
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
for i in range(8):
print(f'{i} in binary is {i:>03b}')
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
0 in binary is 000
1 in binary is 001
2 in binary is 010
3 in binary is 011
4 in binary is 100
5 in binary is 101
6 in binary is 110
7 in binary is 111
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
for i in range(8):
print(f'{i} in binary is {i:>03b}')
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
0 in binary is 000
1 in binary is 001
2 in binary is 010
3 in binary is 011
4 in binary is 100
5 in binary is 101
6 in binary is 110
7 in binary is 111
.. raw:: html
.. raw:: html
相対的な位置情報
~~~~~~~~~~~~~~~~
絶対的な位置情報を捉えることに加えて、 上記の位置エンコーディングは
モデルが相対位置に基づいて容易に注意を学習することも可能にする。
これは、 任意の固定位置オフセット :math:`\delta` に対して、 位置
:math:`i + \delta` における位置エンコーディングが 位置 :math:`i`
におけるものの線形射影として表せるからである。
この射影は 数学的に説明できる。 :math:`\omega_j = 1/10000^{2j/d}`
とおくと、 :eq:`eq_positional-encoding-def` における任意の
:math:`(p_{i, 2j}, p_{i, 2j+1})` の組は、 任意の固定オフセット
:math:`\delta` に対して :math:`(p_{i+\delta, 2j}, p_{i+\delta, 2j+1})`
へ線形射影できる。
.. math::
\begin{aligned}
\begin{bmatrix} \cos(\delta \omega_j) & \sin(\delta \omega_j) \\ -\sin(\delta \omega_j) & \cos(\delta \omega_j) \\ \end{bmatrix}
\begin{bmatrix} p_{i, 2j} \\ p_{i, 2j+1} \\ \end{bmatrix}
=&\begin{bmatrix} \cos(\delta \omega_j) \sin(i \omega_j) + \sin(\delta \omega_j) \cos(i \omega_j) \\ -\sin(\delta \omega_j) \sin(i \omega_j) + \cos(\delta \omega_j) \cos(i \omega_j) \\ \end{bmatrix}\\
=&\begin{bmatrix} \sin\left((i+\delta) \omega_j\right) \\ \cos\left((i+\delta) \omega_j\right) \\ \end{bmatrix}\\
=&
\begin{bmatrix} p_{i+\delta, 2j} \\ p_{i+\delta, 2j+1} \\ \end{bmatrix},
\end{aligned}
ここで、\ :math:`2\times 2` の射影行列は どの位置インデックス :math:`i`
にも依存しない。
まとめ
------
自己注意では、クエリ、キー、値はすべて同じ場所から来る。
CNNと自己注意はいずれも並列計算の恩恵を受け、
自己注意は最大経路長が最も短い。
しかし、系列長に対して二次の計算量を持つため、
自己注意は非常に長い系列に対しては 極めて遅くなる。
系列順序の情報を使うには、
入力表現に位置エンコーディングを加えることで、
絶対的または相対的な位置情報を注入できる。
演習
----
1. 位置エンコーディングを用いた自己注意層を積み重ねることで系列を表現する深いアーキテクチャを設計するとする。どのような問題が起こりうるだろうか。
2. 学習可能な位置エンコーディング手法を設計できるか?
3. 自己注意で比較されるクエリとキーの間の異なるオフセットに応じて、異なる学習済み埋め込みを割り当てることはできるか?
ヒント:相対位置埋め込みを参照されたい
:cite:`shaw2018self,huang2018music`\ 。