.. _sec_transformer:
Transformerアーキテクチャ
=========================
:numref:`subsec_cnn-rnn-self-attention`
では、CNN、RNN、自己注意を比較した。 特に、自己注意は並列計算と
最短の最大経路長の両方を備えている。 したがって、
自己注意を用いて深いアーキテクチャを設計するのは魅力的である。
入力表現に対してなおRNNに依存していた 以前の自己注意モデル
:cite:`Cheng.Dong.Lapata.2016,Lin.Feng.Santos.ea.2017,Paulus.Xiong.Socher.2017`
とは異なり、 Transformerモデルは 畳み込み層も再帰層も使わず、
注意機構のみに基づいている :cite:`Vaswani.Shazeer.Parmar.ea.2017`\ 。
もともとは テキストデータに対する系列変換学習のために提案されたが、
Transformerは 言語、視覚、音声、強化学習など、
現代の深層学習の幅広い応用分野で 広く使われるようになっている。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
from d2l import torch as d2l
import math
import pandas as pd
import torch
from torch import nn
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
from d2l import mxnet as d2l
import math
from mxnet import autograd, init, np, npx
from mxnet.gluon import nn
import pandas as pd
npx.set_np()
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
from d2l import jax as d2l
from flax import linen as nn
from jax import numpy as jnp
import jax
import math
import pandas as pd
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
from d2l import tensorflow as d2l
import numpy as np
import pandas as pd
import tensorflow as tf
.. raw:: html
.. raw:: html
モデル
------
エンコーダ–デコーダ アーキテクチャの一例として、
Transformerの全体アーキテクチャを :numref:`fig_transformer` に示す。
見てわかるように、 Transformerはエンコーダとデコーダから構成される。
:numref:`fig_s2s_attention_details` の
Bahdanau注意による系列変換学習とは対照的に、
入力(source)系列と出力(target)系列の埋め込みには、
自己注意に基づくモジュールを積み重ねた
エンコーダとデコーダに入力する前に、 位置エンコーディングが加えられる。
.. _fig_transformer:
.. figure:: ../img/transformer.svg
:width: 320px
The Transformer architecture.
ここで、 :numref:`fig_transformer` における
Transformerアーキテクチャの概要を説明する。 高レベルでは、
Transformerエンコーダは複数の同一層のスタックであり、
各層は2つのサブレイヤー(いずれも :math:`\textrm{sublayer}`
と表記)を持つ。 1つ目は マルチヘッド自己注意プーリングであり、
2つ目は位置ごとのフィードフォワードネットワークである。 具体的には、
エンコーダの自己注意では、 クエリ、キー、値はすべて
前のエンコーダ層の出力から得られる。 :numref:`sec_resnet`
のResNet設計に着想を得て、 両方のサブレイヤーの周囲に
残差接続が用いられる。 Transformerでは、
系列の任意の位置にある任意の入力 :math:`\mathbf{x} \in \mathbb{R}^d`
に対して、 残差接続
:math:`\mathbf{x} + \textrm{sublayer}(\mathbf{x}) \in \mathbb{R}^d`
が可能であるように、
:math:`\textrm{sublayer}(\mathbf{x}) \in \mathbb{R}^d` を要求する。
この残差接続による加算の直後に 層正規化が続く
:cite:`Ba.Kiros.Hinton.2016`\ 。
その結果、Transformerエンコーダは入力系列の各位置に対して :math:`d`
次元のベクトル表現を出力する。
Transformerデコーダもまた、 残差接続と層正規化を備えた
複数の同一層のスタックである。
エンコーダで説明した2つのサブレイヤーに加えて、 デコーダは その間に
エンコーダ–デコーダ注意と呼ばれる 3つ目のサブレイヤーを挿入する。
エンコーダ–デコーダ注意では、 クエリは
デコーダの自己注意サブレイヤーの出力から得られ、 キーと値は
Transformerエンコーダの出力から得られる。 デコーダの自己注意では、
クエリ、キー、値はすべて 前のデコーダ層の出力から得られる。
ただし、デコーダの各位置は その位置までのデコーダ内のすべての位置にのみ
注意を向けることが許される。 この\ *マスク付き*\ 注意は
自己回帰性を保ち、
予測が生成済みの出力トークンのみに依存することを保証する。
すでに :numref:`sec_multihead-attention` で
スケールド・ドット積に基づく マルチヘッド注意と、
:numref:`subsec_positional-encoding` で
位置エンコーディングを説明し実装した。 以下では、
Transformerモデルの残りの部分を実装する。
.. _subsec_positionwise-ffn:
位置ごとのフィードフォワードネットワーク
----------------------------------------
位置ごとのフィードフォワードネットワークは、 同じMLPを用いて
すべての系列位置の表現を変換する。
このため、これを\ *位置ごと*\ と呼ぶ。 以下の実装では、 形状が
(バッチサイズ、時間ステップ数またはトークン単位の系列長、
隠れユニット数または特徴次元) である入力 ``X`` は、 2層MLPによって
形状が (バッチサイズ、時間ステップ数、\ ``ffn_num_outputs``\ )
の出力テンソルへ変換される。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class PositionWiseFFN(nn.Module): #@save
"""The positionwise feed-forward network."""
def __init__(self, ffn_num_hiddens, ffn_num_outputs):
super().__init__()
self.dense1 = nn.LazyLinear(ffn_num_hiddens)
self.relu = nn.ReLU()
self.dense2 = nn.LazyLinear(ffn_num_outputs)
def forward(self, X):
return self.dense2(self.relu(self.dense1(X)))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class PositionWiseFFN(nn.Block): #@save
"""The positionwise feed-forward network."""
def __init__(self, ffn_num_hiddens, ffn_num_outputs):
super().__init__()
self.dense1 = nn.Dense(ffn_num_hiddens, flatten=False,
activation='relu')
self.dense2 = nn.Dense(ffn_num_outputs, flatten=False)
def forward(self, X):
return self.dense2(self.dense1(X))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class PositionWiseFFN(nn.Module): #@save
"""The positionwise feed-forward network."""
ffn_num_hiddens: int
ffn_num_outputs: int
def setup(self):
self.dense1 = nn.Dense(self.ffn_num_hiddens)
self.dense2 = nn.Dense(self.ffn_num_outputs)
def __call__(self, X):
return self.dense2(nn.relu(self.dense1(X)))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class PositionWiseFFN(tf.keras.layers.Layer): #@save
"""The positionwise feed-forward network."""
def __init__(self, ffn_num_hiddens, ffn_num_outputs):
super().__init__()
self.dense1 = tf.keras.layers.Dense(ffn_num_hiddens)
self.relu = tf.keras.layers.ReLU()
self.dense2 = tf.keras.layers.Dense(ffn_num_outputs)
def call(self, X):
return self.dense2(self.relu(self.dense1(X)))
.. raw:: html
.. raw:: html
次の例は、 テンソルの最内側の次元が変化することを示している。
その変化先は 位置ごとのフィードフォワードネットワークの出力数である。
同じMLPが すべての位置で変換を行うため、
それらすべての位置で入力が同じなら、 出力も同一になる。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
ffn = PositionWiseFFN(4, 8)
ffn.eval()
ffn(d2l.ones((2, 3, 4)))[0]
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
tensor([[-0.4807, 0.1850, 0.1983, 0.0662, 0.3201, 0.5081, 0.1139, -0.1333],
[-0.4807, 0.1850, 0.1983, 0.0662, 0.3201, 0.5081, 0.1139, -0.1333],
[-0.4807, 0.1850, 0.1983, 0.0662, 0.3201, 0.5081, 0.1139, -0.1333]],
grad_fn=)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
ffn = PositionWiseFFN(4, 8)
ffn.initialize()
ffn(np.ones((2, 3, 4)))[0]
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
[07:43:27] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
array([[ 0.00239431, 0.00927085, -0.00021069, -0.00923989, -0.0082903 ,
-0.00162741, 0.00659031, 0.00023905],
[ 0.00239431, 0.00927085, -0.00021069, -0.00923989, -0.0082903 ,
-0.00162741, 0.00659031, 0.00023905],
[ 0.00239431, 0.00927085, -0.00021069, -0.00923989, -0.0082903 ,
-0.00162741, 0.00659031, 0.00023905]])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
ffn = PositionWiseFFN(4, 8)
ffn.init_with_output(d2l.get_key(), jnp.ones((2, 3, 4)))[0][0]
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
Array([[-0.5607809 , 0.00511155, -0.15698639, -0.46223652, -0.06088082,
0.41905117, -0.38441786, 0.28503913],
[-0.5607809 , 0.00511155, -0.15698639, -0.46223652, -0.06088082,
0.41905117, -0.38441786, 0.28503913],
[-0.5607809 , 0.00511155, -0.15698639, -0.46223652, -0.06088082,
0.41905117, -0.38441786, 0.28503913]], dtype=float32)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
ffn = PositionWiseFFN(4, 8)
ffn(tf.ones((2, 3, 4)))[0]
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
.. raw:: html
.. raw:: html
残差接続と層正規化
------------------
ここでは、 :numref:`fig_transformer` の「add &
norm」コンポーネントに注目しよう。 この節の冒頭で述べたように、
これは残差接続の直後に 層正規化が続く構成である。
どちらも効果的な深層アーキテクチャにとって重要である。
:numref:`sec_batch_norm` では、 バッチ正規化が
ミニバッチ内の各例にわたって
どのように再中心化と再スケーリングを行うかを説明した。
:numref:`subsec_layer-normalization-in-bn` で議論したように、
層正規化はバッチ正規化と同じだが、
前者は特徴次元に沿って正規化する点が異なる。
そのため、スケール不変性とバッチサイズ不変性の利点を持つ。
コンピュータビジョンで広く使われているにもかかわらず、 バッチ正規化は、
入力がしばしば可変長系列である自然言語処理タスクでは、
経験的に層正規化ほど有効でないことが多い。
次のコード片は、
層正規化とバッチ正規化による異なる次元に沿った正規化を比較する。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
ln = nn.LayerNorm(2)
bn = nn.LazyBatchNorm1d()
X = d2l.tensor([[1, 2], [2, 3]], dtype=torch.float32)
# Compute mean and variance from X in the training mode
print('layer norm:', ln(X), '\nbatch norm:', bn(X))
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
layer norm: tensor([[-1.0000, 1.0000],
[-1.0000, 1.0000]], grad_fn=)
batch norm: tensor([[-1.0000, -1.0000],
[ 1.0000, 1.0000]], grad_fn=)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
ln = nn.LayerNorm()
ln.initialize()
bn = nn.BatchNorm()
bn.initialize()
X = d2l.tensor([[1, 2], [2, 3]])
# Compute mean and variance from X in the training mode
with autograd.record():
print('layer norm:', ln(X), '\nbatch norm:', bn(X))
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
layer norm: [[-0.99998 0.99998]
[-0.99998 0.99998]]
batch norm: [[-0.99998 -0.99998]
[ 0.99998 0.99998]]
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
ln = nn.LayerNorm()
bn = nn.BatchNorm()
X = d2l.tensor([[1, 2], [2, 3]], dtype=d2l.float32)
# Compute mean and variance from X in the training mode
print('layer norm:', ln.init_with_output(d2l.get_key(), X)[0],
'\nbatch norm:', bn.init_with_output(d2l.get_key(), X,
use_running_average=False)[0])
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
layer norm: [[-0.9999979 0.9999979]
[-0.9999979 0.9999979]]
batch norm: [[-0.9999799 -0.9999799]
[ 0.9999799 0.9999799]]
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
ln = tf.keras.layers.LayerNormalization()
bn = tf.keras.layers.BatchNormalization()
X = tf.constant([[1, 2], [2, 3]], dtype=tf.float32)
print('layer norm:', ln(X), '\nbatch norm:', bn(X, training=True))
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
layer norm: tf.Tensor(
[[-0.998006 0.9980061]
[-0.9980061 0.998006 ]], shape=(2, 2), dtype=float32)
batch norm: tf.Tensor(
[[-0.998006 -0.9980061 ]
[ 0.9980061 0.99800587]], shape=(2, 2), dtype=float32)
.. raw:: html
.. raw:: html
ここで、残差接続の後に層正規化を行う ``AddNorm`` クラスを 実装できる。
正則化のためにドロップアウトも適用する。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class AddNorm(nn.Module): #@save
"""The residual connection followed by layer normalization."""
def __init__(self, norm_shape, dropout):
super().__init__()
self.dropout = nn.Dropout(dropout)
self.ln = nn.LayerNorm(norm_shape)
def forward(self, X, Y):
return self.ln(self.dropout(Y) + X)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class AddNorm(nn.Block): #@save
"""The residual connection followed by layer normalization."""
def __init__(self, dropout):
super().__init__()
self.dropout = nn.Dropout(dropout)
self.ln = nn.LayerNorm()
def forward(self, X, Y):
return self.ln(self.dropout(Y) + X)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class AddNorm(nn.Module): #@save
"""The residual connection followed by layer normalization."""
dropout: int
@nn.compact
def __call__(self, X, Y, training=False):
return nn.LayerNorm()(
nn.Dropout(self.dropout)(Y, deterministic=not training) + X)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class AddNorm(tf.keras.layers.Layer): #@save
"""The residual connection followed by layer normalization."""
def __init__(self, norm_shape, dropout):
super().__init__()
self.dropout = tf.keras.layers.Dropout(dropout)
self.ln = tf.keras.layers.LayerNormalization(norm_shape)
def call(self, X, Y, **kwargs):
return self.ln(self.dropout(Y, **kwargs) + X)
.. raw:: html
.. raw:: html
残差接続では、 加算演算の後も出力テンソルが同じ形状になるように、
2つの入力が同じ形状であることが必要である。
そのため、出力テンソルも加算後に同じ形状を持つ。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
add_norm = AddNorm(4, 0.5)
shape = (2, 3, 4)
d2l.check_shape(add_norm(d2l.ones(shape), d2l.ones(shape)), shape)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
add_norm = AddNorm(0.5)
add_norm.initialize()
shape = (2, 3, 4)
d2l.check_shape(add_norm(d2l.ones(shape), d2l.ones(shape)), shape)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
add_norm = AddNorm(0.5)
shape = (2, 3, 4)
output, _ = add_norm.init_with_output(d2l.get_key(), d2l.ones(shape),
d2l.ones(shape))
d2l.check_shape(output, shape)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
# Normalized_shape is: [i for i in range(len(input.shape))][1:]
add_norm = AddNorm([1, 2], 0.5)
shape = (2, 3, 4)
d2l.check_shape(add_norm(tf.ones(shape), tf.ones(shape), training=False),
shape)
.. raw:: html
.. raw:: html
.. _subsec_transformer-encoder:
エンコーダ
----------
Transformerエンコーダを構成するための 必要な要素がすべてそろったので、
まずは エンコーダ内の1層を実装しよう。 以下の
``TransformerEncoderBlock`` クラスは 2つのサブレイヤーを含む。
すなわち、マルチヘッド自己注意と位置ごとのフィードフォワードネットワークであり、
両方のサブレイヤーの周囲に 残差接続と層正規化が用いられる。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class TransformerEncoderBlock(nn.Module): #@save
"""The Transformer encoder block."""
def __init__(self, num_hiddens, ffn_num_hiddens, num_heads, dropout,
use_bias=False):
super().__init__()
self.attention = d2l.MultiHeadAttention(num_hiddens, num_heads,
dropout, use_bias)
self.addnorm1 = AddNorm(num_hiddens, dropout)
self.ffn = PositionWiseFFN(ffn_num_hiddens, num_hiddens)
self.addnorm2 = AddNorm(num_hiddens, dropout)
def forward(self, X, valid_lens):
Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
return self.addnorm2(Y, self.ffn(Y))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class TransformerEncoderBlock(nn.Block): #@save
"""The Transformer encoder block."""
def __init__(self, num_hiddens, ffn_num_hiddens, num_heads, dropout,
use_bias=False):
super().__init__()
self.attention = d2l.MultiHeadAttention(
num_hiddens, num_heads, dropout, use_bias)
self.addnorm1 = AddNorm(dropout)
self.ffn = PositionWiseFFN(ffn_num_hiddens, num_hiddens)
self.addnorm2 = AddNorm(dropout)
def forward(self, X, valid_lens):
Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))
return self.addnorm2(Y, self.ffn(Y))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class TransformerEncoderBlock(nn.Module): #@save
"""The Transformer encoder block."""
num_hiddens: int
ffn_num_hiddens: int
num_heads: int
dropout: float
use_bias: bool = False
def setup(self):
self.attention = d2l.MultiHeadAttention(self.num_hiddens, self.num_heads,
self.dropout, self.use_bias)
self.addnorm1 = AddNorm(self.dropout)
self.ffn = PositionWiseFFN(self.ffn_num_hiddens, self.num_hiddens)
self.addnorm2 = AddNorm(self.dropout)
def __call__(self, X, valid_lens, training=False):
output, attention_weights = self.attention(X, X, X, valid_lens,
training=training)
Y = self.addnorm1(X, output, training=training)
return self.addnorm2(Y, self.ffn(Y), training=training), attention_weights
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class TransformerEncoderBlock(tf.keras.layers.Layer): #@save
"""The Transformer encoder block."""
def __init__(self, key_size, query_size, value_size, num_hiddens,
norm_shape, ffn_num_hiddens, num_heads, dropout, bias=False):
super().__init__()
self.attention = d2l.MultiHeadAttention(
key_size, query_size, value_size, num_hiddens, num_heads, dropout,
bias)
self.addnorm1 = AddNorm(norm_shape, dropout)
self.ffn = PositionWiseFFN(ffn_num_hiddens, num_hiddens)
self.addnorm2 = AddNorm(norm_shape, dropout)
def call(self, X, valid_lens, **kwargs):
Y = self.addnorm1(X, self.attention(X, X, X, valid_lens, **kwargs),
**kwargs)
return self.addnorm2(Y, self.ffn(Y), **kwargs)
.. raw:: html
.. raw:: html
見てわかるように、 Transformerエンコーダのどの層も
入力の形状を変えない。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X = d2l.ones((2, 100, 24))
valid_lens = d2l.tensor([3, 2])
encoder_blk = TransformerEncoderBlock(24, 48, 8, 0.5)
encoder_blk.eval()
d2l.check_shape(encoder_blk(X, valid_lens), X.shape)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X = d2l.ones((2, 100, 24))
valid_lens = d2l.tensor([3, 2])
encoder_blk = TransformerEncoderBlock(24, 48, 8, 0.5)
encoder_blk.initialize()
d2l.check_shape(encoder_blk(X, valid_lens), X.shape)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X = jnp.ones((2, 100, 24))
valid_lens = jnp.array([3, 2])
encoder_blk = TransformerEncoderBlock(24, 48, 8, 0.5)
(output, _), _ = encoder_blk.init_with_output(d2l.get_key(), X, valid_lens,
training=False)
d2l.check_shape(output, X.shape)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X = tf.ones((2, 100, 24))
valid_lens = tf.constant([3, 2])
norm_shape = [i for i in range(len(X.shape))][1:]
encoder_blk = TransformerEncoderBlock(24, 24, 24, 24, norm_shape, 48, 8, 0.5)
d2l.check_shape(encoder_blk(X, valid_lens, training=False), X.shape)
.. raw:: html
.. raw:: html
以下のTransformerエンコーダの実装では、 上記の
``TransformerEncoderBlock`` クラスの ``num_blks``
個のインスタンスを積み重ねる。 固定位置エンコーディングの値は常に
:math:`-1` と :math:`1` の間にあるため、
入力埋め込みと位置エンコーディングを加算する前に、
学習可能な入力埋め込みの値に 埋め込み次元の平方根を掛けて
再スケーリングする。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class TransformerEncoder(d2l.Encoder): #@save
"""The Transformer encoder."""
def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens,
num_heads, num_blks, dropout, use_bias=False):
super().__init__()
self.num_hiddens = num_hiddens
self.embedding = nn.Embedding(vocab_size, num_hiddens)
self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
self.blks = nn.Sequential()
for i in range(num_blks):
self.blks.add_module("block"+str(i), TransformerEncoderBlock(
num_hiddens, ffn_num_hiddens, num_heads, dropout, use_bias))
def forward(self, X, valid_lens):
# Since positional encoding values are between -1 and 1, the embedding
# values are multiplied by the square root of the embedding dimension
# to rescale before they are summed up
X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
self.attention_weights = [None] * len(self.blks)
for i, blk in enumerate(self.blks):
X = blk(X, valid_lens)
self.attention_weights[
i] = blk.attention.attention.attention_weights
return X
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class TransformerEncoder(d2l.Encoder): #@save
"""The Transformer encoder."""
def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens,
num_heads, num_blks, dropout, use_bias=False):
super().__init__()
self.num_hiddens = num_hiddens
self.embedding = nn.Embedding(vocab_size, num_hiddens)
self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
self.blks = nn.Sequential()
for _ in range(num_blks):
self.blks.add(TransformerEncoderBlock(
num_hiddens, ffn_num_hiddens, num_heads, dropout, use_bias))
self.initialize()
def forward(self, X, valid_lens):
# Since positional encoding values are between -1 and 1, the embedding
# values are multiplied by the square root of the embedding dimension
# to rescale before they are summed up
X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
self.attention_weights = [None] * len(self.blks)
for i, blk in enumerate(self.blks):
X = blk(X, valid_lens)
self.attention_weights[
i] = blk.attention.attention.attention_weights
return X
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class TransformerEncoder(d2l.Encoder): #@save
"""The Transformer encoder."""
vocab_size: int
num_hiddens:int
ffn_num_hiddens: int
num_heads: int
num_blks: int
dropout: float
use_bias: bool = False
def setup(self):
self.embedding = nn.Embed(self.vocab_size, self.num_hiddens)
self.pos_encoding = d2l.PositionalEncoding(self.num_hiddens, self.dropout)
self.blks = [TransformerEncoderBlock(self.num_hiddens,
self.ffn_num_hiddens,
self.num_heads,
self.dropout, self.use_bias)
for _ in range(self.num_blks)]
def __call__(self, X, valid_lens, training=False):
# Since positional encoding values are between -1 and 1, the embedding
# values are multiplied by the square root of the embedding dimension
# to rescale before they are summed up
X = self.embedding(X) * math.sqrt(self.num_hiddens)
X = self.pos_encoding(X, training=training)
attention_weights = [None] * len(self.blks)
for i, blk in enumerate(self.blks):
X, attention_w = blk(X, valid_lens, training=training)
attention_weights[i] = attention_w
# Flax sow API is used to capture intermediate variables
self.sow('intermediates', 'enc_attention_weights', attention_weights)
return X
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class TransformerEncoder(d2l.Encoder): #@save
"""The Transformer encoder."""
def __init__(self, vocab_size, key_size, query_size, value_size,
num_hiddens, norm_shape, ffn_num_hiddens, num_heads,
num_blks, dropout, bias=False):
super().__init__()
self.num_hiddens = num_hiddens
self.embedding = tf.keras.layers.Embedding(vocab_size, num_hiddens)
self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
self.blks = [TransformerEncoderBlock(
key_size, query_size, value_size, num_hiddens, norm_shape,
ffn_num_hiddens, num_heads, dropout, bias) for _ in range(
num_blks)]
def call(self, X, valid_lens, **kwargs):
# Since positional encoding values are between -1 and 1, the embedding
# values are multiplied by the square root of the embedding dimension
# to rescale before they are summed up
X = self.pos_encoding(self.embedding(X) * tf.math.sqrt(
tf.cast(self.num_hiddens, dtype=tf.float32)), **kwargs)
self.attention_weights = [None] * len(self.blks)
for i, blk in enumerate(self.blks):
X = blk(X, valid_lens, **kwargs)
self.attention_weights[
i] = blk.attention.attention.attention_weights
return X
.. raw:: html
.. raw:: html
以下では、ハイパーパラメータを指定して
2層のTransformerエンコーダを作成する。 Transformerエンコーダの出力形状は
(バッチサイズ、時間ステップ数、\ ``num_hiddens``\ )である。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
encoder = TransformerEncoder(200, 24, 48, 8, 2, 0.5)
d2l.check_shape(encoder(d2l.ones((2, 100), dtype=torch.long), valid_lens),
(2, 100, 24))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
encoder = TransformerEncoder(200, 24, 48, 8, 2, 0.5)
d2l.check_shape(encoder(np.ones((2, 100)), valid_lens), (2, 100, 24))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
encoder = TransformerEncoder(200, 24, 48, 8, 2, 0.5)
d2l.check_shape(encoder.init_with_output(d2l.get_key(),
jnp.ones((2, 100), dtype=jnp.int32),
valid_lens)[0],
(2, 100, 24))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
encoder = TransformerEncoder(200, 24, 24, 24, 24, [1, 2], 48, 8, 2, 0.5)
d2l.check_shape(encoder(tf.ones((2, 100)), valid_lens, training=False),
(2, 100, 24))
.. raw:: html
.. raw:: html
デコーダ
--------
:numref:`fig_transformer` に示すように、
Transformerデコーダは複数の同一層から構成される。 各層は以下の
``TransformerDecoderBlock`` クラスで実装され、 3つのサブレイヤーを含む。
すなわち、 デコーダ自己注意、 エンコーダ–デコーダ注意、
位置ごとのフィードフォワードネットワークである。
これらのサブレイヤーには、 それぞれの周囲に
残差接続があり、その後に層正規化が続く。
この節の前半で述べたように、 マスク付きマルチヘッドデコーダ自己注意
(1つ目のサブレイヤー)では、 クエリ、キー、値は
すべて前のデコーダ層の出力から得られる。 系列変換モデルを学習するとき、
出力系列のすべての位置(時間ステップ)のトークンは 既知である。 しかし、
予測時には 出力系列はトークンごとに生成される。 したがって、
デコーダの任意の時間ステップでは 生成済みトークンのみを
デコーダ自己注意に使うことができる。 デコーダの自己回帰性を保つために、
マスク付き自己注意は ``dec_valid_lens`` を指定し、 各クエリが
デコーダ内のクエリ位置までの すべての位置にのみ注意を向けるようにする。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class TransformerDecoderBlock(nn.Module):
# The i-th block in the Transformer decoder
def __init__(self, num_hiddens, ffn_num_hiddens, num_heads, dropout, i):
super().__init__()
self.i = i
self.attention1 = d2l.MultiHeadAttention(num_hiddens, num_heads,
dropout)
self.addnorm1 = AddNorm(num_hiddens, dropout)
self.attention2 = d2l.MultiHeadAttention(num_hiddens, num_heads,
dropout)
self.addnorm2 = AddNorm(num_hiddens, dropout)
self.ffn = PositionWiseFFN(ffn_num_hiddens, num_hiddens)
self.addnorm3 = AddNorm(num_hiddens, dropout)
def forward(self, X, state):
enc_outputs, enc_valid_lens = state[0], state[1]
# During training, all the tokens of any output sequence are processed
# at the same time, so state[2][self.i] is None as initialized. When
# decoding any output sequence token by token during prediction,
# state[2][self.i] contains representations of the decoded output at
# the i-th block up to the current time step
if state[2][self.i] is None:
key_values = X
else:
key_values = torch.cat((state[2][self.i], X), dim=1)
state[2][self.i] = key_values
if self.training:
batch_size, num_steps, _ = X.shape
# Shape of dec_valid_lens: (batch_size, num_steps), where every
# row is [1, 2, ..., num_steps]
dec_valid_lens = torch.arange(
1, num_steps + 1, device=X.device).repeat(batch_size, 1)
else:
dec_valid_lens = None
# Self-attention
X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
Y = self.addnorm1(X, X2)
# Encoder-decoder attention. Shape of enc_outputs:
# (batch_size, num_steps, num_hiddens)
Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
Z = self.addnorm2(Y, Y2)
return self.addnorm3(Z, self.ffn(Z)), state
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class TransformerDecoderBlock(nn.Block):
# The i-th block in the Transformer decoder
def __init__(self, num_hiddens, ffn_num_hiddens, num_heads, dropout, i):
super().__init__()
self.i = i
self.attention1 = d2l.MultiHeadAttention(num_hiddens, num_heads,
dropout)
self.addnorm1 = AddNorm(dropout)
self.attention2 = d2l.MultiHeadAttention(num_hiddens, num_heads,
dropout)
self.addnorm2 = AddNorm(dropout)
self.ffn = PositionWiseFFN(ffn_num_hiddens, num_hiddens)
self.addnorm3 = AddNorm(dropout)
def forward(self, X, state):
enc_outputs, enc_valid_lens = state[0], state[1]
# During training, all the tokens of any output sequence are processed
# at the same time, so state[2][self.i] is None as initialized. When
# decoding any output sequence token by token during prediction,
# state[2][self.i] contains representations of the decoded output at
# the i-th block up to the current time step
if state[2][self.i] is None:
key_values = X
else:
key_values = np.concatenate((state[2][self.i], X), axis=1)
state[2][self.i] = key_values
if autograd.is_training():
batch_size, num_steps, _ = X.shape
# Shape of dec_valid_lens: (batch_size, num_steps), where every
# row is [1, 2, ..., num_steps]
dec_valid_lens = np.tile(np.arange(1, num_steps + 1, ctx=X.ctx),
(batch_size, 1))
else:
dec_valid_lens = None
# Self-attention
X2 = self.attention1(X, key_values, key_values, dec_valid_lens)
Y = self.addnorm1(X, X2)
# Encoder-decoder attention. Shape of enc_outputs:
# (batch_size, num_steps, num_hiddens)
Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens)
Z = self.addnorm2(Y, Y2)
return self.addnorm3(Z, self.ffn(Z)), state
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class TransformerDecoderBlock(nn.Module):
# The i-th block in the Transformer decoder
num_hiddens: int
ffn_num_hiddens: int
num_heads: int
dropout: float
i: int
def setup(self):
self.attention1 = d2l.MultiHeadAttention(self.num_hiddens,
self.num_heads,
self.dropout)
self.addnorm1 = AddNorm(self.dropout)
self.attention2 = d2l.MultiHeadAttention(self.num_hiddens,
self.num_heads,
self.dropout)
self.addnorm2 = AddNorm(self.dropout)
self.ffn = PositionWiseFFN(self.ffn_num_hiddens, self.num_hiddens)
self.addnorm3 = AddNorm(self.dropout)
def __call__(self, X, state, training=False):
enc_outputs, enc_valid_lens = state[0], state[1]
# During training, all the tokens of any output sequence are processed
# at the same time, so state[2][self.i] is None as initialized. When
# decoding any output sequence token by token during prediction,
# state[2][self.i] contains representations of the decoded output at
# the i-th block up to the current time step
if state[2][self.i] is None:
key_values = X
else:
key_values = jnp.concatenate((state[2][self.i], X), axis=1)
state[2][self.i] = key_values
if training:
batch_size, num_steps, _ = X.shape
# Shape of dec_valid_lens: (batch_size, num_steps), where every
# row is [1, 2, ..., num_steps]
dec_valid_lens = jnp.tile(jnp.arange(1, num_steps + 1),
(batch_size, 1))
else:
dec_valid_lens = None
# Self-attention
X2, attention_w1 = self.attention1(X, key_values, key_values,
dec_valid_lens, training=training)
Y = self.addnorm1(X, X2, training=training)
# Encoder-decoder attention. Shape of enc_outputs:
# (batch_size, num_steps, num_hiddens)
Y2, attention_w2 = self.attention2(Y, enc_outputs, enc_outputs,
enc_valid_lens, training=training)
Z = self.addnorm2(Y, Y2, training=training)
return self.addnorm3(Z, self.ffn(Z), training=training), state, attention_w1, attention_w2
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class TransformerDecoderBlock(tf.keras.layers.Layer):
# The i-th block in the Transformer decoder
def __init__(self, key_size, query_size, value_size, num_hiddens,
norm_shape, ffn_num_hiddens, num_heads, dropout, i):
super().__init__()
self.i = i
self.attention1 = d2l.MultiHeadAttention(
key_size, query_size, value_size, num_hiddens, num_heads, dropout)
self.addnorm1 = AddNorm(norm_shape, dropout)
self.attention2 = d2l.MultiHeadAttention(
key_size, query_size, value_size, num_hiddens, num_heads, dropout)
self.addnorm2 = AddNorm(norm_shape, dropout)
self.ffn = PositionWiseFFN(ffn_num_hiddens, num_hiddens)
self.addnorm3 = AddNorm(norm_shape, dropout)
def call(self, X, state, **kwargs):
enc_outputs, enc_valid_lens = state[0], state[1]
# During training, all the tokens of any output sequence are processed
# at the same time, so state[2][self.i] is None as initialized. When
# decoding any output sequence token by token during prediction,
# state[2][self.i] contains representations of the decoded output at
# the i-th block up to the current time step
if state[2][self.i] is None:
key_values = X
else:
key_values = tf.concat((state[2][self.i], X), axis=1)
state[2][self.i] = key_values
if kwargs["training"]:
batch_size, num_steps, _ = X.shape
# Shape of dec_valid_lens: (batch_size, num_steps), where every
# row is [1, 2, ..., num_steps]
dec_valid_lens = tf.repeat(
tf.reshape(tf.range(1, num_steps + 1),
shape=(-1, num_steps)), repeats=batch_size, axis=0)
else:
dec_valid_lens = None
# Self-attention
X2 = self.attention1(X, key_values, key_values, dec_valid_lens,
**kwargs)
Y = self.addnorm1(X, X2, **kwargs)
# Encoder-decoder attention. Shape of enc_outputs:
# (batch_size, num_steps, num_hiddens)
Y2 = self.attention2(Y, enc_outputs, enc_outputs, enc_valid_lens,
**kwargs)
Z = self.addnorm2(Y, Y2, **kwargs)
return self.addnorm3(Z, self.ffn(Z), **kwargs), state
.. raw:: html
.. raw:: html
エンコーダ–デコーダ注意におけるスケールド・ドット積演算と
残差接続における加算演算を容易にするため、
デコーダの特徴次元(\ ``num_hiddens``\ )は
エンコーダのそれと同じである。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
decoder_blk = TransformerDecoderBlock(24, 48, 8, 0.5, 0)
X = d2l.ones((2, 100, 24))
state = [encoder_blk(X, valid_lens), valid_lens, [None]]
d2l.check_shape(decoder_blk(X, state)[0], X.shape)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
decoder_blk = TransformerDecoderBlock(24, 48, 8, 0.5, 0)
decoder_blk.initialize()
X = np.ones((2, 100, 24))
state = [encoder_blk(X, valid_lens), valid_lens, [None]]
d2l.check_shape(decoder_blk(X, state)[0], X.shape)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
decoder_blk = TransformerDecoderBlock(24, 48, 8, 0.5, 0)
X = d2l.ones((2, 100, 24))
state = [encoder_blk.init_with_output(d2l.get_key(), X, valid_lens)[0][0],
valid_lens, [None]]
d2l.check_shape(decoder_blk.init_with_output(d2l.get_key(), X, state)[0][0],
X.shape)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
decoder_blk = TransformerDecoderBlock(24, 24, 24, 24, [1, 2], 48, 8, 0.5, 0)
X = tf.ones((2, 100, 24))
state = [encoder_blk(X, valid_lens), valid_lens, [None]]
d2l.check_shape(decoder_blk(X, state, training=False)[0], X.shape)
.. raw:: html
.. raw:: html
ここで、\ ``num_blks`` 個の ``TransformerDecoderBlock``
インスタンスからなる Transformerデコーダ全体を構成する。 最後に、
全結合層が ``vocab_size``
個の可能な出力トークンすべてに対する予測を計算する。
デコーダ自己注意の重みと エンコーダ–デコーダ注意の重みの両方は、
後で可視化できるように保存される。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class TransformerDecoder(d2l.AttentionDecoder):
def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens, num_heads,
num_blks, dropout):
super().__init__()
self.num_hiddens = num_hiddens
self.num_blks = num_blks
self.embedding = nn.Embedding(vocab_size, num_hiddens)
self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
self.blks = nn.Sequential()
for i in range(num_blks):
self.blks.add_module("block"+str(i), TransformerDecoderBlock(
num_hiddens, ffn_num_hiddens, num_heads, dropout, i))
self.dense = nn.LazyLinear(vocab_size)
def init_state(self, enc_outputs, enc_valid_lens):
return [enc_outputs, enc_valid_lens, [None] * self.num_blks]
def forward(self, X, state):
X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
self._attention_weights = [[None] * len(self.blks) for _ in range (2)]
for i, blk in enumerate(self.blks):
X, state = blk(X, state)
# Decoder self-attention weights
self._attention_weights[0][
i] = blk.attention1.attention.attention_weights
# Encoder-decoder attention weights
self._attention_weights[1][
i] = blk.attention2.attention.attention_weights
return self.dense(X), state
@property
def attention_weights(self):
return self._attention_weights
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class TransformerDecoder(d2l.AttentionDecoder):
def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens, num_heads,
num_blks, dropout):
super().__init__()
self.num_hiddens = num_hiddens
self.num_blks = num_blks
self.embedding = nn.Embedding(vocab_size, num_hiddens)
self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
self.blks = nn.Sequential()
for i in range(num_blks):
self.blks.add(TransformerDecoderBlock(
num_hiddens, ffn_num_hiddens, num_heads, dropout, i))
self.dense = nn.Dense(vocab_size, flatten=False)
self.initialize()
def init_state(self, enc_outputs, enc_valid_lens):
return [enc_outputs, enc_valid_lens, [None] * self.num_blks]
def forward(self, X, state):
X = self.pos_encoding(self.embedding(X) * math.sqrt(self.num_hiddens))
self._attention_weights = [[None] * len(self.blks) for _ in range (2)]
for i, blk in enumerate(self.blks):
X, state = blk(X, state)
# Decoder self-attention weights
self._attention_weights[0][
i] = blk.attention1.attention.attention_weights
# Encoder-decoder attention weights
self._attention_weights[1][
i] = blk.attention2.attention.attention_weights
return self.dense(X), state
@property
def attention_weights(self):
return self._attention_weights
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class TransformerDecoder(nn.Module):
vocab_size: int
num_hiddens: int
ffn_num_hiddens: int
num_heads: int
num_blks: int
dropout: float
def setup(self):
self.embedding = nn.Embed(self.vocab_size, self.num_hiddens)
self.pos_encoding = d2l.PositionalEncoding(self.num_hiddens,
self.dropout)
self.blks = [TransformerDecoderBlock(self.num_hiddens,
self.ffn_num_hiddens,
self.num_heads, self.dropout, i)
for i in range(self.num_blks)]
self.dense = nn.Dense(self.vocab_size)
def init_state(self, enc_outputs, enc_valid_lens):
return [enc_outputs, enc_valid_lens, [None] * self.num_blks]
def __call__(self, X, state, training=False):
X = self.embedding(X) * jnp.sqrt(jnp.float32(self.num_hiddens))
X = self.pos_encoding(X, training=training)
attention_weights = [[None] * len(self.blks) for _ in range(2)]
for i, blk in enumerate(self.blks):
X, state, attention_w1, attention_w2 = blk(X, state,
training=training)
# Decoder self-attention weights
attention_weights[0][i] = attention_w1
# Encoder-decoder attention weights
attention_weights[1][i] = attention_w2
# Flax sow API is used to capture intermediate variables
self.sow('intermediates', 'dec_attention_weights', attention_weights)
return self.dense(X), state
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class TransformerDecoder(d2l.AttentionDecoder):
def __init__(self, vocab_size, key_size, query_size, value_size,
num_hiddens, norm_shape, ffn_num_hiddens, num_heads,
num_blks, dropout):
super().__init__()
self.num_hiddens = num_hiddens
self.num_blks = num_blks
self.embedding = tf.keras.layers.Embedding(vocab_size, num_hiddens)
self.pos_encoding = d2l.PositionalEncoding(num_hiddens, dropout)
self.blks = [TransformerDecoderBlock(
key_size, query_size, value_size, num_hiddens, norm_shape,
ffn_num_hiddens, num_heads, dropout, i)
for i in range(num_blks)]
self.dense = tf.keras.layers.Dense(vocab_size)
def init_state(self, enc_outputs, enc_valid_lens):
return [enc_outputs, enc_valid_lens, [None] * self.num_blks]
def call(self, X, state, **kwargs):
X = self.pos_encoding(self.embedding(X) * tf.math.sqrt(
tf.cast(self.num_hiddens, dtype=tf.float32)), **kwargs)
# 2 attention layers in decoder
self._attention_weights = [[None] * len(self.blks) for _ in range(2)]
for i, blk in enumerate(self.blks):
X, state = blk(X, state, **kwargs)
# Decoder self-attention weights
self._attention_weights[0][i] = (
blk.attention1.attention.attention_weights)
# Encoder-decoder attention weights
self._attention_weights[1][i] = (
blk.attention2.attention.attention_weights)
return self.dense(X), state
@property
def attention_weights(self):
return self._attention_weights
.. raw:: html
.. raw:: html
学習
----
Transformerアーキテクチャに従って
エンコーダ–デコーダモデルを実装しよう。 ここでは、
TransformerエンコーダとTransformerデコーダの両方が
4ヘッド注意を用いた2層構成であるとする。
:numref:`sec_seq2seq_training` と同様に、 英仏機械翻訳データセット上で
系列変換学習のために Transformerモデルを学習する。
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
data = d2l.MTFraEng(batch_size=128)
num_hiddens, num_blks, dropout = 256, 2, 0.2
ffn_num_hiddens, num_heads = 64, 4
if tab.selected('tensorflow'):
key_size, query_size, value_size = 256, 256, 256
norm_shape = [2]
if tab.selected('pytorch', 'mxnet', 'jax'):
encoder = TransformerEncoder(
len(data.src_vocab), num_hiddens, ffn_num_hiddens, num_heads,
num_blks, dropout)
decoder = TransformerDecoder(
len(data.tgt_vocab), num_hiddens, ffn_num_hiddens, num_heads,
num_blks, dropout)
if tab.selected('mxnet', 'pytorch'):
model = d2l.Seq2Seq(encoder, decoder, tgt_pad=data.tgt_vocab[''],
lr=0.001)
trainer = d2l.Trainer(max_epochs=30, gradient_clip_val=1, num_gpus=1)
if tab.selected('jax'):
model = d2l.Seq2Seq(encoder, decoder, tgt_pad=data.tgt_vocab[''],
lr=0.001, training=True)
trainer = d2l.Trainer(max_epochs=30, gradient_clip_val=1, num_gpus=1)
if tab.selected('tensorflow'):
with d2l.try_gpu():
encoder = TransformerEncoder(
len(data.src_vocab), key_size, query_size, value_size, num_hiddens,
norm_shape, ffn_num_hiddens, num_heads, num_blks, dropout)
decoder = TransformerDecoder(
len(data.tgt_vocab), key_size, query_size, value_size, num_hiddens,
norm_shape, ffn_num_hiddens, num_heads, num_blks, dropout)
model = d2l.Seq2Seq(encoder, decoder, tgt_pad=data.tgt_vocab[''],
lr=0.001)
trainer = d2l.Trainer(max_epochs=30, gradient_clip_val=1)
trainer.fit(model, data)
.. figure:: output_transformer_0594a4_196_0.svg
学習後、 Transformerモデルを用いて
いくつかの英語文をフランス語に翻訳し、そのBLEUスコアを計算する。
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
engs = ['go .', 'i lost .', 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
if tab.selected('pytorch', 'mxnet', 'tensorflow'):
preds, _ = model.predict_step(
data.build(engs, fras), d2l.try_gpu(), data.num_steps)
if tab.selected('jax'):
preds, _ = model.predict_step(
trainer.state.params, data.build(engs, fras), data.num_steps)
for en, fr, p in zip(engs, fras, preds):
translation = []
for token in data.tgt_vocab.to_tokens(p):
if token == '':
break
translation.append(token)
print(f'{en} => {translation}, bleu,'
f'{d2l.bleu(" ".join(translation), fr, k=2):.3f}')
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
go . => ['va', '!'], bleu,1.000
i lost . => ["j'ai", 'perdu', '.'], bleu,1.000
he's calm . => ['', '.'], bleu,0.000
i'm home . => ['je', 'suis', 'chez', 'moi', '.'], bleu,1.000
最後の英語文をフランス語に翻訳するときの
Transformerの注意重みを可視化する。 エンコーダ自己注意重みの形状は
(エンコーダ層数、注意ヘッド数、\ ``num_steps``
またはクエリ数、\ ``num_steps`` またはキー・値ペア数)である。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
_, dec_attention_weights = model.predict_step(
data.build([engs[-1]], [fras[-1]]), d2l.try_gpu(), data.num_steps, True)
enc_attention_weights = d2l.concat(model.encoder.attention_weights, 0)
shape = (num_blks, num_heads, -1, data.num_steps)
enc_attention_weights = d2l.reshape(enc_attention_weights, shape)
d2l.check_shape(enc_attention_weights,
(num_blks, num_heads, data.num_steps, data.num_steps))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
_, dec_attention_weights = model.predict_step(
data.build([engs[-1]], [fras[-1]]), d2l.try_gpu(), data.num_steps, True)
enc_attention_weights = d2l.concat(model.encoder.attention_weights, 0)
shape = (num_blks, num_heads, -1, data.num_steps)
enc_attention_weights = d2l.reshape(enc_attention_weights, shape)
d2l.check_shape(enc_attention_weights,
(num_blks, num_heads, data.num_steps, data.num_steps))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
_, (dec_attention_weights, enc_attention_weights) = model.predict_step(
trainer.state.params, data.build([engs[-1]], [fras[-1]]),
data.num_steps, True)
enc_attention_weights = d2l.concat(enc_attention_weights, 0)
shape = (num_blks, num_heads, -1, data.num_steps)
enc_attention_weights = d2l.reshape(enc_attention_weights, shape)
d2l.check_shape(enc_attention_weights,
(num_blks, num_heads, data.num_steps, data.num_steps))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
_, dec_attention_weights = model.predict_step(
data.build([engs[-1]], [fras[-1]]), d2l.try_gpu(), data.num_steps, True)
enc_attention_weights = d2l.concat(model.encoder.attention_weights, 0)
shape = (num_blks, num_heads, -1, data.num_steps)
enc_attention_weights = d2l.reshape(enc_attention_weights, shape)
d2l.check_shape(enc_attention_weights,
(num_blks, num_heads, data.num_steps, data.num_steps))
.. raw:: html
.. raw:: html
エンコーダ自己注意では、 クエリとキーの両方が同じ入力系列から来る。
パディングトークンは意味を持たないため、
入力系列の有効長が指定されていれば、
どのクエリもパディングトークンの位置には注意を向けない。 以下では、
マルチヘッド注意重みの2層を 行ごとに示す。 各ヘッドは、
クエリ、キー、値の別々の表現部分空間に基づいて 独立に注意を向ける。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
d2l.show_heatmaps(
enc_attention_weights.cpu(), xlabel='Key positions',
ylabel='Query positions', titles=['Head %d' % i for i in range(1, 5)],
figsize=(7, 3.5))
.. figure:: output_transformer_0594a4_217_0.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
d2l.show_heatmaps(
enc_attention_weights, xlabel='Key positions', ylabel='Query positions',
titles=['Head %d' % i for i in range(1, 5)], figsize=(7, 3.5))
.. figure:: output_transformer_0594a4_220_0.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
d2l.show_heatmaps(
enc_attention_weights, xlabel='Key positions', ylabel='Query positions',
titles=['Head %d' % i for i in range(1, 5)], figsize=(7, 3.5))
.. figure:: output_transformer_0594a4_223_0.svg
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
d2l.show_heatmaps(
enc_attention_weights, xlabel='Key positions', ylabel='Query positions',
titles=['Head %d' % i for i in range(1, 5)], figsize=(7, 3.5))
.. figure:: output_transformer_0594a4_226_0.svg
.. raw:: html
.. raw:: html
デコーダ自己注意重みとエンコーダ–デコーダ注意重みを可視化するには、
さらにデータ操作が必要である。 たとえば、
マスクされた注意重みを0で埋める。 なお、 デコーダ自己注意重みと
エンコーダ–デコーダ注意重みは どちらも同じクエリを持つ。 すなわち、
系列開始トークンの後に 出力トークン、場合によっては
系列終了トークンが続く。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
dec_attention_weights_2d = [head[0].tolist()
for step in dec_attention_weights
for attn in step for blk in attn for head in blk]
dec_attention_weights_filled = d2l.tensor(
pd.DataFrame(dec_attention_weights_2d).fillna(0.0).values)
shape = (-1, 2, num_blks, num_heads, data.num_steps)
dec_attention_weights = d2l.reshape(dec_attention_weights_filled, shape)
dec_self_attention_weights, dec_inter_attention_weights = \
dec_attention_weights.permute(1, 2, 3, 0, 4)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
dec_attention_weights_2d = [d2l.tensor(head[0]).tolist()
for step in dec_attention_weights
for attn in step for blk in attn for head in blk]
dec_attention_weights_filled = d2l.tensor(
pd.DataFrame(dec_attention_weights_2d).fillna(0.0).values)
dec_attention_weights = d2l.reshape(dec_attention_weights_filled, (
-1, 2, num_blks, num_heads, data.num_steps))
dec_self_attention_weights, dec_inter_attention_weights = \
dec_attention_weights.transpose(1, 2, 3, 0, 4)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
dec_attention_weights_2d = [head[0].tolist() for step in dec_attention_weights
for attn in step
for blk in attn for head in blk]
dec_attention_weights_filled = d2l.tensor(
pd.DataFrame(dec_attention_weights_2d).fillna(0.0).values)
dec_attention_weights = dec_attention_weights_filled.reshape(
(-1, 2, num_blks, num_heads, data.num_steps))
dec_self_attention_weights, dec_inter_attention_weights = \
dec_attention_weights.transpose(1, 2, 3, 0, 4)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
dec_attention_weights_2d = [head[0] for step in dec_attention_weights
for attn in step
for blk in attn for head in blk]
dec_attention_weights_filled = tf.convert_to_tensor(
np.asarray(pd.DataFrame(dec_attention_weights_2d).fillna(
0.0).values).astype(np.float32))
dec_attention_weights = tf.reshape(dec_attention_weights_filled, shape=(
-1, 2, num_blks, num_heads, data.num_steps))
dec_self_attention_weights, dec_inter_attention_weights = tf.transpose(
dec_attention_weights, perm=(1, 2, 3, 0, 4))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
d2l.check_shape(dec_self_attention_weights,
(num_blks, num_heads, data.num_steps, data.num_steps))
d2l.check_shape(dec_inter_attention_weights,
(num_blks, num_heads, data.num_steps, data.num_steps))
デコーダ自己注意の自己回帰性のため、
どのクエリもクエリ位置より後のキー・値ペアには注意を向けない。
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
d2l.show_heatmaps(
dec_self_attention_weights[:, :, :, :],
xlabel='Key positions', ylabel='Query positions',
titles=['Head %d' % i for i in range(1, 5)], figsize=(7, 3.5))
.. figure:: output_transformer_0594a4_246_0.svg
エンコーダ自己注意の場合と同様に、 入力系列の有効長を指定することで、
出力系列からのどのクエリも
入力系列中のパディングトークンには注意を向けない。
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
d2l.show_heatmaps(
dec_inter_attention_weights, xlabel='Key positions',
ylabel='Query positions', titles=['Head %d' % i for i in range(1, 5)],
figsize=(7, 3.5))
.. figure:: output_transformer_0594a4_248_0.svg
Transformerアーキテクチャは もともと系列変換学習のために提案されたが、
本書の後半で見るように、 Transformerエンコーダまたは
Transformerデコーダのいずれか一方が さまざまな深層学習タスクで
個別に使われることが多い。
要約
----
Transformerはエンコーダ–デコーダアーキテクチャの一例であるが、
実際にはエンコーダまたはデコーダのどちらか一方だけを使うこともできる。
Transformerアーキテクチャでは、マルチヘッド自己注意が
入力系列と出力系列の表現に用いられるが、
デコーダではマスク付き版によって自己回帰性を保たなければならない。
Transformerにおける残差接続と層正規化は、
非常に深いモデルを学習するうえで重要である。
Transformerモデルの位置ごとのフィードフォワードネットワークは、
同じMLPを用いてすべての系列位置の表現を変換する。
演習
----
1. 実験でより深いTransformerを学習せよ。学習速度と翻訳性能にどのような影響があるか。
2. Transformerでスケールド・ドット積注意を加法注意に置き換えるのはよい考えか。なぜか。
3. 言語モデル化では、Transformerのエンコーダ、デコーダ、あるいは両方のどれを使うべきか。この方法をどのように設計するか。
4. 入力系列が非常に長い場合、Transformerはどのような課題に直面するか。なぜか。
5. Transformerの計算効率とメモリ効率をどのように改善するか。ヒント::cite:t:`Tay.Dehghani.Bahri.ea.2020`
のサーベイ論文を参照されたい。