11.8. 画像向けTransformer

Transformerアーキテクチャは当初、 機械翻訳に焦点を当てた 系列変換学習のために提案された。 その後、Transformerは さまざまな自然言語処理タスクにおける第一選択のモデルとして台頭した (Brown et al., 2020, Devlin et al., 2018, Radford et al., 2018, Radford et al., 2019, Raffel et al., 2020)。 しかし、コンピュータビジョンの分野では、 支配的なアーキテクチャは依然として CNNのままであった(8 章)。 当然ながら、研究者たちは Transformerモデルを画像データに適用することで より良い性能が得られるのではないかと考え始めた。 この問いは、コンピュータビジョンコミュニティに 大きな関心を呼び起こした。 最近では、Ramachandran et al. (2019) が 畳み込みを自己注意で置き換える方式を提案した。 しかし、注意機構に特殊なパターンを用いるため、 ハードウェアアクセラレータ上でモデルを大規模化しにくい。 その後、Cordonnier et al. (2020) は理論的に、 自己注意が畳み込みと同様に振る舞うよう学習できることを証明した。 実証的には、画像から \(2 \times 2\) のパッチを入力として取り出したが、 パッチサイズが小さいため、このモデルは 低解像度の画像データにしか適用できない。

パッチサイズに特別な制約を設けずに、 vision Transformers(ViT)は 画像からパッチを抽出し、 それらをTransformerエンコーダに入力して グローバルな表現を得る。 そして最終的に、その表現を分類用に変換する (Dosovitskiy et al., 2021)。 特筆すべきは、TransformerはCNNよりもスケーラビリティに優れていることである。 より大きなデータセットでより大規模なモデルを学習すると、 vision TransformerはResNetを大きく上回る。 自然言語処理におけるネットワークアーキテクチャ設計の潮流と同様に、 Transformerはコンピュータビジョンにおいてもゲームチェンジャーとなった。

from d2l import torch as d2l
import torch
from torch import nn
img_size, patch_size = 96, 16
num_hiddens, mlp_num_hiddens, num_heads, num_blks = 512, 2048, 8, 2
emb_dropout, blk_dropout, lr = 0.1, 0.1, 0.1
model = ViT(img_size, patch_size, num_hiddens, mlp_num_hiddens, num_heads,
            num_blks, emb_dropout, blk_dropout, lr)
trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
data = d2l.FashionMNIST(batch_size=128, resize=(img_size, img_size))
trainer.fit(model, data)
from d2l import jax as d2l
from flax import linen as nn
import jax
from jax import numpy as jnp
img_size, patch_size = 96, 16
num_hiddens, mlp_num_hiddens, num_heads, num_blks = 512, 2048, 8, 2
emb_dropout, blk_dropout, lr = 0.1, 0.1, 0.1
model = ViT(img_size, patch_size, num_hiddens, mlp_num_hiddens, num_heads,
            num_blks, emb_dropout, blk_dropout, lr)
trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
data = d2l.FashionMNIST(batch_size=128, resize=(img_size, img_size))
trainer.fit(model, data)

11.8.1. モデル

図 11.8.1 は vision Transformerのモデルアーキテクチャを示している。 このアーキテクチャは、 画像をパッチ化するstem、 多層Transformerエンコーダに基づくbody、 そしてグローバル表現を 出力ラベルへ変換するheadから構成される。

../_images/vit.svg

図 11.8.1 The vision Transformer architecture. In this example, an image is split into nine patches. A special “<cls>” token and the nine flattened image patches are transformed via patch embedding and \(\mathit{n}\) Transformer encoder blocks into ten representations, respectively. The “<cls>” representation is further transformed into the output label.

高さ \(h\)、幅 \(w\)、 チャネル数 \(c\) をもつ入力画像を考える。 パッチの高さと幅をともに \(p\) とすると、 画像は \(m = hw/p^2\) 個のパッチ列に分割され、 各パッチは長さ \(cp^2\) のベクトルに平坦化される。 このようにして、画像パッチはTransformerエンコーダによって テキスト系列中のトークンと同様に扱うことができる。 特別な “<cls>”(class)トークンと \(m\) 個の平坦化された画像パッチは線形射影されて \(m+1\) 個のベクトル列となり、 学習可能な位置埋め込みが加算される。 多層Transformerエンコーダは \(m+1\) 個の入力ベクトルを 同じ長さの \(m+1\) 個の出力ベクトル表現へ変換する。 これは 図 11.7.1 における元のTransformerエンコーダと まったく同じように動作し、 正規化の位置だけが異なる。 “<cls>” トークンは自己注意を通じて すべての画像パッチに注意を向けるため(図 11.6.1 を参照)、 Transformerエンコーダ出力におけるその表現は さらに出力ラベルへ変換される。

11.8.2. パッチ埋め込み

vision Transformerを実装するには、まず 図 11.8.1 のパッチ埋め込みから始めよう。 画像をパッチに分割し、 それらの平坦化されたパッチを線形射影する操作は、 カーネルサイズとストライドサイズの両方をパッチサイズに設定した 1つの畳み込み演算として簡略化できる。

class PatchEmbedding(nn.Module):
    def __init__(self, img_size=96, patch_size=16, num_hiddens=512):
        super().__init__()
        def _make_tuple(x):
            if not isinstance(x, (list, tuple)):
                return (x, x)
            return x
        img_size, patch_size = _make_tuple(img_size), _make_tuple(patch_size)
        self.num_patches = (img_size[0] // patch_size[0]) * (
            img_size[1] // patch_size[1])
        self.conv = nn.LazyConv2d(num_hiddens, kernel_size=patch_size,
                                  stride=patch_size)

    def forward(self, X):
        # Output shape: (batch size, no. of patches, no. of channels)
        return self.conv(X).flatten(2).transpose(1, 2)
class PatchEmbedding(nn.Module):
    img_size: int = 96
    patch_size: int = 16
    num_hiddens: int = 512

    def setup(self):
        def _make_tuple(x):
            if not isinstance(x, (list, tuple)):
                return (x, x)
            return x
        img_size, patch_size = _make_tuple(self.img_size), _make_tuple(self.patch_size)
        self.num_patches = (img_size[0] // patch_size[0]) * (
            img_size[1] // patch_size[1])
        self.conv = nn.Conv(self.num_hiddens, kernel_size=patch_size,
                            strides=patch_size, padding='SAME')

    def __call__(self, X):
        # Output shape: (batch size, no. of patches, no. of channels)
        X = self.conv(X)
        return X.reshape((X.shape[0], -1, X.shape[3]))

次の例では、高さと幅が img_size の画像を入力として、 パッチ埋め込みは (img_size//patch_size)**2 個のパッチを出力し、 それらは長さ num_hiddens のベクトルへ線形射影される。

img_size, patch_size, num_hiddens, batch_size = 96, 16, 512, 4
patch_emb = PatchEmbedding(img_size, patch_size, num_hiddens)
X = d2l.zeros(batch_size, 3, img_size, img_size)
d2l.check_shape(patch_emb(X),
                (batch_size, (img_size//patch_size)**2, num_hiddens))
img_size, patch_size, num_hiddens, batch_size = 96, 16, 512, 4
patch_emb = PatchEmbedding(img_size, patch_size, num_hiddens)
X = d2l.zeros((batch_size, img_size, img_size, 3))
output, _ = patch_emb.init_with_output(d2l.get_key(), X)
d2l.check_shape(output, (batch_size, (img_size//patch_size)**2, num_hiddens))

11.8.3. Vision Transformerエンコーダ

vision TransformerエンコーダのMLPは、 元のTransformerエンコーダの位置ごとのFFNとは少し異なる(11.7.2 章 を参照)。 第一に、ここでは活性化関数としてガウス誤差線形ユニット(GELU)を用いる。 これはReLUのより滑らかな版とみなせる (Hendrycks and Gimpel, 2016)。 第二に、正則化のために、MLP内の各全結合層の出力にドロップアウトを適用する。

class ViTMLP(nn.Module):
    def __init__(self, mlp_num_hiddens, mlp_num_outputs, dropout=0.5):
        super().__init__()
        self.dense1 = nn.LazyLinear(mlp_num_hiddens)
        self.gelu = nn.GELU()
        self.dropout1 = nn.Dropout(dropout)
        self.dense2 = nn.LazyLinear(mlp_num_outputs)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout2(self.dense2(self.dropout1(self.gelu(
            self.dense1(x)))))
class ViTMLP(nn.Module):
    mlp_num_hiddens: int
    mlp_num_outputs: int
    dropout: float = 0.5

    @nn.compact
    def __call__(self, x, training=False):
        x = nn.Dense(self.mlp_num_hiddens)(x)
        x = nn.gelu(x)
        x = nn.Dropout(self.dropout, deterministic=not training)(x)
        x = nn.Dense(self.mlp_num_outputs)(x)
        x = nn.Dropout(self.dropout, deterministic=not training)(x)
        return x

vision Transformerエンコーダブロックの実装は、 図 11.8.1 における事前正規化の設計に従っている。 ここでは、正規化はマルチヘッド注意またはMLPの直前に適用される。 図 11.7.1 の「add & norm」のような事後正規化では、 正規化は残差接続の直後に置かれるのに対し、 事前正規化はTransformerの学習をより効果的または効率的にする (Baevski and Auli, 2018, Wang et al., 2019, Xiong et al., 2020)

class ViTBlock(nn.Module):
    def __init__(self, num_hiddens, norm_shape, mlp_num_hiddens,
                 num_heads, dropout, use_bias=False):
        super().__init__()
        self.ln1 = nn.LayerNorm(norm_shape)
        self.attention = d2l.MultiHeadAttention(num_hiddens, num_heads,
                                                dropout, use_bias)
        self.ln2 = nn.LayerNorm(norm_shape)
        self.mlp = ViTMLP(mlp_num_hiddens, num_hiddens, dropout)

    def forward(self, X, valid_lens=None):
        X = X + self.attention(*([self.ln1(X)] * 3), valid_lens)
        return X + self.mlp(self.ln2(X))
class ViTBlock(nn.Module):
    num_hiddens: int
    mlp_num_hiddens: int
    num_heads: int
    dropout: float
    use_bias: bool = False

    def setup(self):
        self.attention = d2l.MultiHeadAttention(self.num_hiddens, self.num_heads,
                                                self.dropout, self.use_bias)
        self.mlp = ViTMLP(self.mlp_num_hiddens, self.num_hiddens, self.dropout)

    @nn.compact
    def __call__(self, X, valid_lens=None, training=False):
        X = X + self.attention(*([nn.LayerNorm()(X)] * 3),
                               valid_lens, training=training)[0]
        return X + self.mlp(nn.LayerNorm()(X), training=training)

11.7.4 章 と同様に、 vision Transformerエンコーダブロックは入力の形状を変えない。

X = d2l.ones((2, 100, 24))
encoder_blk = ViTBlock(24, 24, 48, 8, 0.5)
encoder_blk.eval()
d2l.check_shape(encoder_blk(X), X.shape)
X = d2l.ones((2, 100, 24))
encoder_blk = ViTBlock(24, 48, 8, 0.5)
d2l.check_shape(encoder_blk.init_with_output(d2l.get_key(), X)[0], X.shape)

11.8.4. 全体をまとめる

以下のvision Transformerの順伝播は単純である。 まず、入力画像は PatchEmbedding インスタンスに与えられ、 その出力は “<cls>” トークン埋め込みと連結される。 それらはドロップアウトの前に、学習可能な位置埋め込みと加算される。 次に、その出力は ViTBlock クラスのインスタンスを num_blks 個積み重ねたTransformerエンコーダに入力される。 最後に、“<cls>” トークンの表現がネットワークのheadによって射影される。

class ViT(d2l.Classifier):
    """Vision Transformer."""
    def __init__(self, img_size, patch_size, num_hiddens, mlp_num_hiddens,
                 num_heads, num_blks, emb_dropout, blk_dropout, lr=0.1,
                 use_bias=False, num_classes=10):
        super().__init__()
        self.save_hyperparameters()
        self.patch_embedding = PatchEmbedding(
            img_size, patch_size, num_hiddens)
        self.cls_token = nn.Parameter(d2l.zeros(1, 1, num_hiddens))
        num_steps = self.patch_embedding.num_patches + 1  # Add the cls token
        # Positional embeddings are learnable
        self.pos_embedding = nn.Parameter(
            torch.randn(1, num_steps, num_hiddens))
        self.dropout = nn.Dropout(emb_dropout)
        self.blks = nn.Sequential()
        for i in range(num_blks):
            self.blks.add_module(f"{i}", ViTBlock(
                num_hiddens, num_hiddens, mlp_num_hiddens,
                num_heads, blk_dropout, use_bias))
        self.head = nn.Sequential(nn.LayerNorm(num_hiddens),
                                  nn.Linear(num_hiddens, num_classes))

    def forward(self, X):
        X = self.patch_embedding(X)
        X = d2l.concat((self.cls_token.expand(X.shape[0], -1, -1), X), 1)
        X = self.dropout(X + self.pos_embedding)
        for blk in self.blks:
            X = blk(X)
        return self.head(X[:, 0])
class ViT(d2l.Classifier):
    """Vision Transformer."""
    img_size: int
    patch_size: int
    num_hiddens: int
    mlp_num_hiddens: int
    num_heads: int
    num_blks: int
    emb_dropout: float
    blk_dropout: float
    lr: float = 0.1
    use_bias: bool = False
    num_classes: int = 10
    training: bool = False

    def setup(self):
        self.patch_embedding = PatchEmbedding(self.img_size, self.patch_size,
                                              self.num_hiddens)
        self.cls_token = self.param('cls_token', nn.initializers.zeros,
                                    (1, 1, self.num_hiddens))
        num_steps = self.patch_embedding.num_patches + 1  # Add the cls token
        # Positional embeddings are learnable
        self.pos_embedding = self.param('pos_embed', nn.initializers.normal(),
                                        (1, num_steps, self.num_hiddens))
        self.blks = [ViTBlock(self.num_hiddens, self.mlp_num_hiddens,
                              self.num_heads, self.blk_dropout, self.use_bias)
                    for _ in range(self.num_blks)]
        self.head = nn.Sequential([nn.LayerNorm(), nn.Dense(self.num_classes)])

    @nn.compact
    def __call__(self, X):
        X = self.patch_embedding(X)
        X = d2l.concat((jnp.tile(self.cls_token, (X.shape[0], 1, 1)), X), 1)
        X = nn.Dropout(emb_dropout, deterministic=not self.training)(X + self.pos_embedding)
        for blk in self.blks:
            X = blk(X, training=self.training)
        return self.head(X[:, 0])

11.8.5. 学習

Fashion-MNISTデータセットでvision Transformerを学習するのは、 8 章 でCNNを学習したときと同じである。

img_size, patch_size = 96, 16
num_hiddens, mlp_num_hiddens, num_heads, num_blks = 512, 2048, 8, 2
emb_dropout, blk_dropout, lr = 0.1, 0.1, 0.1
model = ViT(img_size, patch_size, num_hiddens, mlp_num_hiddens, num_heads,
            num_blks, emb_dropout, blk_dropout, lr)
trainer = d2l.Trainer(max_epochs=10, num_gpus=1)
data = d2l.FashionMNIST(batch_size=128, resize=(img_size, img_size))
trainer.fit(model, data)
../_images/output_vision-transformer_5c928f_70_0.svg

11.8.6. 要約と考察

Fashion-MNISTのような小規模データセットでは、 実装したvision Transformerが 8.6 章 のResNetを上回らないことに 気づいたかもしれない。 同様の観察は、ImageNetデータセット(120万枚の画像)でも成り立つ。 これは、Transformerには 畳み込みにおける有用な帰納バイアス、 たとえば平行移動不変性や局所性(7.1 章)が 欠けている ためである。 しかし、より大きなデータセット(たとえば3億枚の画像)で より大規模なモデルを学習すると状況は変わり、 その場合、vision Transformerは画像分類でResNetを大きく上回り、 スケーラビリティにおけるTransformerの本質的な優位性を示している (Dosovitskiy et al., 2021)。 vision Transformerの導入は、 画像データをモデル化するためのネットワーク設計の潮流を変えた。 その後すぐに、DeiTのデータ効率の高い学習戦略によって ImageNetデータセットで有効であることが示された (Touvron et al., 2021)。 しかし、自己注意の二次計算量 (11.6 章) のため、Transformerアーキテクチャは 高解像度画像にはあまり適していない。 コンピュータビジョンにおける汎用バックボーンネットワークを目指して、 Swin Transformerは画像サイズに対する二次的な計算複雑性を (11.6.2 章) 解消し、畳み込みに似た事前知識を再導入した。 その結果、Transformerの適用範囲は 画像分類を超えたさまざまなコンピュータビジョンタスクへと広がり、 最先端の結果を達成している (Liu et al., 2021)

11.8.7. 演習

  1. img_size の値は学習時間にどのように影響するか。

  2. “<cls>” トークン表現を出力へ射影する代わりに、平均化したパッチ表現をどのように射影するか。これを実装し、精度への影響を調べよ。

  3. ハイパーパラメータを調整して、vision Transformerの精度を改善できるか。