11.2. 類似度によるアテンションプーリング

ここまででアテンション機構の主要な構成要素を導入したので、今度はそれらをかなり古典的な設定、すなわちカーネル密度推定による回帰と分類 (Nadaraya, 1964, Watson, 1964) に用いてみよう。この寄り道は単に追加の背景を与えるだけである。完全に任意であり、必要なら飛ばして構わない。
Nadaraya–Watson 推定量の本質は、クエリ \(\mathbf{q}\) とキー \(\mathbf{k}\) を結び付ける何らかの類似度カーネル \(\alpha(\mathbf{q}, \mathbf{k})\) にある。代表的なカーネルには次のようなものがある。
(11.2.1)\[\begin{split}\begin{aligned} \alpha(\mathbf{q}, \mathbf{k}) & = \exp\left(-\frac{1}{2} \|\mathbf{q} - \mathbf{k}\|^2 \right) && \textrm{Gaussian;} \\ \alpha(\mathbf{q}, \mathbf{k}) & = 1 \textrm{ if } \|\mathbf{q} - \mathbf{k}\| \leq 1 && \textrm{Boxcar;} \\ \alpha(\mathbf{q}, \mathbf{k}) & = \mathop{\mathrm{max}}\left(0, 1 - \|\mathbf{q} - \mathbf{k}\|\right) && \textrm{Epanechikov.} \end{aligned}\end{split}\]

他にも多くの選択肢がある。より詳しい概説と、カーネルの選択がカーネル密度推定、しばしば Parzen Windows とも呼ばれるもの (Parzen, 1957) とどう関係するかについては、Wikipedia の記事 を参照されたい。これらのカーネルはいずれもヒューリスティックであり、調整可能である。たとえば、幅は全体としてだけでなく、各座標ごとにも調整できる。いずれにせよ、どれも回帰と分類の両方に対して次の式を導く。

(11.2.2)\[f(\mathbf{q}) = \sum_i \mathbf{v}_i \frac{\alpha(\mathbf{q}, \mathbf{k}_i)}{\sum_j \alpha(\mathbf{q}, \mathbf{k}_j)}.\]

特徴量とラベルの観測 \((\mathbf{x}_i, y_i)\) を用いる(スカラー)回帰の場合、\(\mathbf{v}_i = y_i\) はスカラー、\(\mathbf{k}_i = \mathbf{x}_i\) はベクトルであり、クエリ \(\mathbf{q}\)\(f\) を評価すべき新しい位置を表す。(多クラス)分類の場合は、\(y_i\) の one-hot エンコーディングを用いて \(\mathbf{v}_i\) を得る。この推定量の便利な性質の一つは、学習を必要としないことである。さらに、データ量の増加に応じてカーネルを適切に狭めれば、この手法は整合的であり (Mack and Silverman, 1982)、すなわち統計的に最適な解のいずれかに収束する。まずはいくつかのカーネルを見てみよう。

from d2l import torch as d2l
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np

d2l.use_svg_display()
from d2l import mxnet as d2l
from mxnet import autograd, gluon, np, npx
from mxnet.gluon import nn
npx.set_np()
d2l.use_svg_display()
from d2l import jax as d2l
import jax
from jax import numpy as jnp
from flax import linen as nn
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
from d2l import tensorflow as d2l
import tensorflow as tf
import numpy as np

d2l.use_svg_display()

11.2.1. カーネルとデータ

この節で定義するすべてのカーネル \(\alpha(\mathbf{k}, \mathbf{q})\)平行移動および回転に不変 である。つまり、\(\mathbf{k}\)\(\mathbf{q}\) を同じように平行移動・回転させても、\(\alpha\) の値は変わらない。簡単のため、ここではスカラー引数 \(k, q \in \mathbb{R}\) を取り、キー \(k = 0\) を原点として選ぶ。すると次のようになる。

# いくつかのカーネルを定義する
def gaussian(x):
    return d2l.exp(-x**2 / 2)

def boxcar(x):
    return d2l.abs(x) < 1.0

def constant(x):
    return 1.0 + 0 * x

def epanechikov(x):
    return torch.max(1 - d2l.abs(x), torch.zeros_like(x))
# いくつかのカーネルを定義する
def gaussian(x):
    return d2l.exp(-x**2 / 2)

def boxcar(x):
    return d2l.abs(x) < 1.0

def constant(x):
    return 1.0 + 0 * x

def epanechikov(x):
    return np.maximum(1 - d2l.abs(x), 0)
# いくつかのカーネルを定義する
def gaussian(x):
    return d2l.exp(-x**2 / 2)

def boxcar(x):
    return d2l.abs(x) < 1.0

def constant(x):
    return 1.0 + 0 * x

def epanechikov(x):
    return jnp.maximum(1 - d2l.abs(x), 0)
# いくつかのカーネルを定義する
def gaussian(x):
    return d2l.exp(-x**2 / 2)

def boxcar(x):
    return d2l.abs(x) < 1.0

def constant(x):
    return 1.0 + 0 * x

def epanechikov(x):
    return tf.maximum(1 - d2l.abs(x), 0)
fig, axes = d2l.plt.subplots(1, 4, sharey=True, figsize=(12, 3))

kernels = (gaussian, boxcar, constant, epanechikov)
names = ('Gaussian', 'Boxcar', 'Constant', 'Epanechikov')
x = d2l.arange(-2.5, 2.5, 0.1)
for kernel, name, ax in zip(kernels, names, axes):
    ax.plot(d2l.numpy(x), d2l.numpy(kernel(x)))
    ax.set_xlabel(name)

d2l.plt.show()
../_images/output_attention-pooling_dce4fd_32_0.svg
fig, axes = d2l.plt.subplots(1, 4, sharey=True, figsize=(12, 3))

kernels = (gaussian, boxcar, constant, epanechikov)
names = ('Gaussian', 'Boxcar', 'Constant', 'Epanechikov')
x = d2l.arange(-2.5, 2.5, 0.1)
for kernel, name, ax in zip(kernels, names, axes):
    ax.plot(d2l.numpy(x), d2l.numpy(kernel(x)))
    ax.set_xlabel(name)

d2l.plt.show()
[07:07:20] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU
../_images/output_attention-pooling_dce4fd_35_1.svg
fig, axes = d2l.plt.subplots(1, 4, sharey=True, figsize=(12, 3))

kernels = (gaussian, boxcar, constant, epanechikov)
names = ('Gaussian', 'Boxcar', 'Constant', 'Epanechikov')
x = d2l.arange(-2.5, 2.5, 0.1)
for kernel, name, ax in zip(kernels, names, axes):
    ax.plot(x, kernel(x))
    ax.set_xlabel(name)

d2l.plt.show()
../_images/output_attention-pooling_dce4fd_38_0.png
fig, axes = d2l.plt.subplots(1, 4, sharey=True, figsize=(12, 3))

kernels = (gaussian, boxcar, constant, epanechikov)
names = ('Gaussian', 'Boxcar', 'Constant', 'Epanechikov')
x = d2l.arange(-2.5, 2.5, 0.1)
for kernel, name, ax in zip(kernels, names, axes):
    ax.plot(d2l.numpy(x), d2l.numpy(kernel(x)))
    ax.set_xlabel(name)

d2l.plt.show()
../_images/output_attention-pooling_dce4fd_41_0.svg

異なるカーネルは、範囲と滑らかさに関する異なる概念に対応する。たとえば、boxcar カーネルは距離 \(1\)(あるいは別に定義したハイパーパラメータ)以内の観測値にしか注目せず、しかもそれを無差別に行う。

Nadaraya–Watson 推定を実際に見てみるために、訓練データを定義しよう。以下では次の依存関係を用いる。

(11.2.3)\[y_i = 2\sin(x_i) + x_i + \epsilon,\]

ここで \(\epsilon\) は平均 0、分散 1 の正規分布から生成される。40 個の訓練例をサンプルする。

def f(x):
    return 2 * d2l.sin(x) + x

n = 40
x_train, _ = torch.sort(d2l.rand(n) * 5)
y_train = f(x_train) + d2l.randn(n)
x_val = d2l.arange(0, 5, 0.1)
y_val = f(x_val)
def f(x):
    return 2 * d2l.sin(x) + x

n = 40
x_train = np.sort(d2l.rand(n) * 5, axis=None)
y_train = f(x_train) + d2l.randn(n)
x_val = d2l.arange(0, 5, 0.1)
y_val = f(x_val)
def f(x):
    return 2 * d2l.sin(x) + x

n = 40
x_train = jnp.sort(jax.random.uniform(d2l.get_key(), (n,)) * 5)
y_train = f(x_train) + jax.random.normal(d2l.get_key(), (n,))
x_val = d2l.arange(0, 5, 0.1)
y_val = f(x_val)
def f(x):
    return 2 * d2l.sin(x) + x

n = 40
x_train = tf.sort(d2l.rand((n,1)) * 5, 0)
y_train = f(x_train) + d2l.normal((n, 1))
x_val = d2l.arange(0, 5, 0.1)
y_val = f(x_val)

11.2.2. Nadaraya–Watson 回帰によるアテンションプーリング

データとカーネルがそろったので、あとはカーネル回帰の推定値を計算する関数だけである。なお、簡単な診断を行うために相対的なカーネル重みも得たいので、まず訓練特徴(共変量)x_train とすべての検証特徴 x_val の間のカーネルを計算する。これにより行列が得られ、それを正規化する。これを訓練ラベル y_train と掛け合わせると推定値が得られる。

(11.1.1) のアテンションプーリングを思い出してほしい。各検証特徴をクエリとし、各訓練特徴–ラベルの組をキー–値のペアとみなす。その結果、正規化された相対カーネル重み(以下の attention_w)が アテンション重み になる。

def nadaraya_watson(x_train, y_train, x_val, kernel):
    dists = d2l.reshape(x_train, (-1, 1)) - d2l.reshape(x_val, (1, -1))
    # 各列/行はそれぞれのクエリ/キーに対応する
    k = d2l.astype(kernel(dists), d2l.float32)
    # 各クエリに対するキー方向の正規化
    attention_w = k / d2l.reduce_sum(k, 0)
    y_hat = y_train@attention_w
    return y_hat, attention_w
def nadaraya_watson(x_train, y_train, x_val, kernel):
    dists = d2l.reshape(x_train, (-1, 1)) - d2l.reshape(x_val, (1, -1))
    # 各列/行はそれぞれのクエリ/キーに対応する
    k = d2l.astype(kernel(dists), d2l.float32)
    # 各クエリに対するキー方向の正規化
    attention_w = k / d2l.reduce_sum(k, 0)
    y_hat = np.dot(y_train, attention_w)
    return y_hat, attention_w
def nadaraya_watson(x_train, y_train, x_val, kernel):
    dists = d2l.reshape(x_train, (-1, 1)) - d2l.reshape(x_val, (1, -1))
    # 各列/行はそれぞれのクエリ/キーに対応する
    k = d2l.astype(kernel(dists), d2l.float32)
    # 各クエリに対するキー方向の正規化
    attention_w = k / d2l.reduce_sum(k, 0)
    y_hat = y_train@attention_w
    return y_hat, attention_w
def nadaraya_watson(x_train, y_train, x_val, kernel):
    dists = d2l.reshape(x_train, (-1, 1)) - d2l.reshape(x_val, (1, -1))
    # 各列/行はそれぞれのクエリ/キーに対応する
    k = d2l.astype(kernel(dists), d2l.float32)
    # 各クエリに対するキー方向の正規化
    attention_w = k / d2l.reduce_sum(k, 0)
    y_hat = d2l.transpose(d2l.transpose(y_train)@attention_w)
    return y_hat, attention_w

異なるカーネルがどのような推定を生み出すか見てみよう。

def plot(x_train, y_train, x_val, y_val, kernels, names, attention=False):
    fig, axes = d2l.plt.subplots(1, 4, sharey=True, figsize=(12, 3))
    for kernel, name, ax in zip(kernels, names, axes):
        y_hat, attention_w = nadaraya_watson(x_train, y_train, x_val, kernel)
        if attention:
            pcm = ax.imshow(d2l.numpy(attention_w), cmap='Reds')
        else:
            ax.plot(x_val, y_hat)
            ax.plot(x_val, y_val, 'm--')
            ax.plot(x_train, y_train, 'o', alpha=0.5);
        ax.set_xlabel(name)
        if not attention:
            ax.legend(['y_hat', 'y'])
    if attention:
        fig.colorbar(pcm, ax=axes, shrink=0.7)
def plot(x_train, y_train, x_val, y_val, kernels, names, attention=False):
    fig, axes = d2l.plt.subplots(1, 4, sharey=True, figsize=(12, 3))
    for kernel, name, ax in zip(kernels, names, axes):
        y_hat, attention_w = nadaraya_watson(x_train, y_train, x_val, kernel)
        if attention:
            pcm = ax.imshow(d2l.numpy(attention_w), cmap='Reds')
        else:
            ax.plot(x_val, y_hat)
            ax.plot(x_val, y_val, 'm--')
            ax.plot(x_train, y_train, 'o', alpha=0.5);
        ax.set_xlabel(name)
        if not attention:
            ax.legend(['y_hat', 'y'])
    if attention:
        fig.colorbar(pcm, ax=axes, shrink=0.7)
def plot(x_train, y_train, x_val, y_val, kernels, names, attention=False):
    fig, axes = d2l.plt.subplots(1, 4, sharey=True, figsize=(12, 3))
    for kernel, name, ax in zip(kernels, names, axes):
        y_hat, attention_w = nadaraya_watson(x_train, y_train, x_val, kernel)
        if attention:
            pcm = ax.imshow(attention_w, cmap='Reds')
        else:
            ax.plot(x_val, y_hat)
            ax.plot(x_val, y_val, 'm--')
            ax.plot(x_train, y_train, 'o', alpha=0.5);
        ax.set_xlabel(name)
        if not attention:
            ax.legend(['y_hat', 'y'])
    if attention:
        fig.colorbar(pcm, ax=axes, shrink=0.7)
def plot(x_train, y_train, x_val, y_val, kernels, names, attention=False):
    fig, axes = d2l.plt.subplots(1, 4, sharey=True, figsize=(12, 3))
    for kernel, name, ax in zip(kernels, names, axes):
        y_hat, attention_w = nadaraya_watson(x_train, y_train, x_val, kernel)
        if attention:
            pcm = ax.imshow(d2l.numpy(attention_w), cmap='Reds')
        else:
            ax.plot(x_val, y_hat)
            ax.plot(x_val, y_val, 'm--')
            ax.plot(x_train, y_train, 'o', alpha=0.5);
        ax.set_xlabel(name)
        if not attention:
            ax.legend(['y_hat', 'y'])
    if attention:
        fig.colorbar(pcm, ax=axes, shrink=0.7)
plot(x_train, y_train, x_val, y_val, kernels, names)
../_images/output_attention-pooling_dce4fd_89_0.svg

まず目につくのは、3 つの非自明なカーネル(Gaussian、Boxcar、Epanechikov)が、真の関数からそれほど離れていない、かなり実用的な推定を与えていることである。自明な推定 \(f(x) = \frac{1}{n} \sum_i y_i\) に帰着する constant カーネルだけが、かなり非現実的な結果を生む。アテンションの重み付けをもう少し詳しく見てみよう。

plot(x_train, y_train, x_val, y_val, kernels, names, attention=True)
../_images/output_attention-pooling_dce4fd_91_0.svg

この可視化から、Gaussian、Boxcar、Epanechikov の推定が非常によく似ている理由がはっきり分かる。結局のところ、カーネルの関数形は異なっていても、非常によく似たアテンション重みから導されているからである。ここで疑問になるのは、これが常に成り立つのかということである。

11.2.3. アテンションプーリングの適応

Gaussian カーネルを別の幅のものに置き換えることができる。つまり、 \(\alpha(\mathbf{q}, \mathbf{k}) = \exp\left(-\frac{1}{2 \sigma^2} \|\mathbf{q} - \mathbf{k}\|^2 \right)\) を使い、\(\sigma^2\) でカーネルの幅を決めることができる。これが結果に影響するか見てみよう。

sigmas = (0.1, 0.2, 0.5, 1)
names = ['Sigma ' + str(sigma) for sigma in sigmas]

def gaussian_with_width(sigma):
    return (lambda x: d2l.exp(-x**2 / (2*sigma**2)))

kernels = [gaussian_with_width(sigma) for sigma in sigmas]
plot(x_train, y_train, x_val, y_val, kernels, names)
../_images/output_attention-pooling_dce4fd_93_0.svg

明らかに、カーネルが狭いほど推定は滑らかでなくなる。同時に、局所的な変動にはよりよく適応する。対応するアテンション重みを見てみよう。

plot(x_train, y_train, x_val, y_val, kernels, names, attention=True)
../_images/output_attention-pooling_dce4fd_95_0.svg

予想どおり、カーネルが狭いほど、大きなアテンション重みを持つ範囲も狭くなる。また、同じ幅を選ぶことが必ずしも理想的ではないことも明らかである。実際、Silverman (1986) は局所密度に依存するヒューリスティックを提案した。このような「工夫」は他にも多数提案されている。たとえば、Norelli et al. (2022) は、クロスモーダルな画像・テキスト表現を設計するために、類似した最近傍補間技術を用いた。

この方法が半世紀以上前のものであるにもかかわらず、なぜここまで詳しく扱うのか、不思議に思う読者もいるかもしれない。第一に、現代のアテンション機構の最も初期の先駆けの一つだからである。第二に、可視化に非常に適している。第三に、そして同じくらい重要であるが、手作りのアテンション機構の限界を示してくれる。より良い戦略は、クエリとキーの表現を学習することで機構そのものを 学習する ことである。これが次の節で取り組む内容である。

11.2.4. まとめ

Nadaraya–Watson カーネル回帰は、現在のアテンション機構の初期の先駆けである。
分類にも回帰にも、ほとんどあるいはまったく学習や調整を必要とせず、そのまま使える。
アテンション重みは、クエリとキーの類似度(または距離)と、どれだけ類似した観測が利用可能かに応じて割り当てられる。

11.2.5. 演習

  1. Parzen windows による密度推定は \(\hat{p}(\mathbf{x}) = \frac{1}{n} \sum_i k(\mathbf{x}, \mathbf{x}_i)\) で与えられる。二値分類に対して、Parzen windows から得られる関数 \(\hat{p}(\mathbf{x}, y=1) - \hat{p}(\mathbf{x}, y=-1)\) が Nadaraya–Watson 分類と等価であることを証明せよ。

  2. Nadaraya–Watson 回帰において、カーネル幅の良い値を学習するための確率的勾配降下法を実装せよ。

    1. 上の推定値をそのまま使って \((f(\mathbf{x_i}) - y_i)^2\) を直接最小化するとどうなるか? ヒント: \(y_i\)\(f\) を計算するために使われる項の一部である。

    2. \(f(\mathbf{x}_i)\) の推定から \((\mathbf{x}_i, y_i)\) を除外し、カーネル幅について最適化せよ。それでも過学習は観測されるか?

  3. すべての \(\mathbf{x}\) が単位球面上にある、すなわちすべて \(\|\mathbf{x}\| = 1\) を満たすと仮定する。指数関数内の \(\|\mathbf{x} - \mathbf{x}_i\|^2\) の項を簡単化できるか? ヒント: 後で、ドット積アテンションと非常に密接に関係していることが分かる。

  4. Mack and Silverman (1982) が Nadaraya–Watson 推定の整合性を証明したことを思い出してほしい。データが増えるにつれて、アテンション機構のスケールをどのくらいの速さで小さくすべきだろうか。 答えの直感も述べよ。データの次元に依存するか? どのように依存するか?