11.2. 類似度によるアテンションプーリング

ここまででアテンション機構の主要な構成要素を導入したので、今度はそれらをかなり古典的な設定、すなわちカーネル密度推定による回帰と分類 (Nadaraya, 1964, Watson, 1964) に用いてみよう。この寄り道は単に追加の背景を与えるだけである。完全に任意であり、必要なら飛ばして構わない。
Nadaraya–Watson 推定量の本質は、クエリ \(\mathbf{q}\) とキー \(\mathbf{k}\) を結び付ける何らかの類似度カーネル \(\alpha(\mathbf{q}, \mathbf{k})\) にある。代表的なカーネルには次のようなものがある。
(11.2.1)\[\begin{split}\begin{aligned} \alpha(\mathbf{q}, \mathbf{k}) & = \exp\left(-\frac{1}{2} \|\mathbf{q} - \mathbf{k}\|^2 \right) && \textrm{Gaussian;} \\ \alpha(\mathbf{q}, \mathbf{k}) & = 1 \textrm{ if } \|\mathbf{q} - \mathbf{k}\| \leq 1 && \textrm{Boxcar;} \\ \alpha(\mathbf{q}, \mathbf{k}) & = \mathop{\mathrm{max}}\left(0, 1 - \|\mathbf{q} - \mathbf{k}\|\right) && \textrm{Epanechikov.} \end{aligned}\end{split}\]

他にも多くの選択肢がある。より詳しい概説と、カーネルの選択がカーネル密度推定、しばしば Parzen Windows とも呼ばれるもの (Parzen, 1957) とどう関係するかについては、Wikipedia の記事 を参照されたい。これらのカーネルはいずれもヒューリスティックであり、調整可能である。たとえば、幅は全体としてだけでなく、各座標ごとにも調整できる。いずれにせよ、どれも回帰と分類の両方に対して次の式を導く。

(11.2.2)\[f(\mathbf{q}) = \sum_i \mathbf{v}_i \frac{\alpha(\mathbf{q}, \mathbf{k}_i)}{\sum_j \alpha(\mathbf{q}, \mathbf{k}_j)}.\]

特徴量とラベルの観測 \((\mathbf{x}_i, y_i)\) を用いる(スカラー)回帰の場合、\(\mathbf{v}_i = y_i\) はスカラー、\(\mathbf{k}_i = \mathbf{x}_i\) はベクトルであり、クエリ \(\mathbf{q}\)\(f\) を評価すべき新しい位置を表す。(多クラス)分類の場合は、\(y_i\) の one-hot エンコーディングを用いて \(\mathbf{v}_i\) を得る。この推定量の便利な性質の一つは、学習を必要としないことである。さらに、データ量の増加に応じてカーネルを適切に狭めれば、この手法は整合的であり (Mack and Silverman, 1982)、すなわち統計的に最適な解のいずれかに収束する。まずはいくつかのカーネルを見てみよう。

%load_ext d2lbook.tab
tab.interact_select('mxnet', 'pytorch', 'tensorflow', 'jax')
from d2l import torch as d2l
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np

d2l.use_svg_display()
from d2l import mxnet as d2l
from mxnet import autograd, gluon, np, npx
from mxnet.gluon import nn
npx.set_np()
d2l.use_svg_display()
from d2l import jax as d2l
import jax
from jax import numpy as jnp
from flax import linen as nn
from d2l import tensorflow as d2l
import tensorflow as tf
import numpy as np

d2l.use_svg_display()

11.2.1. カーネルとデータ

この節で定義するすべてのカーネル \(\alpha(\mathbf{k}, \mathbf{q})\)平行移動および回転に不変 である。つまり、\(\mathbf{k}\)\(\mathbf{q}\) を同じように平行移動・回転させても、\(\alpha\) の値は変わらない。簡単のため、ここではスカラー引数 \(k, q \in \mathbb{R}\) を取り、キー \(k = 0\) を原点として選ぶ。すると次のようになる。

# Define some kernels
def gaussian(x):
    return d2l.exp(-x**2 / 2)

def boxcar(x):
    return d2l.abs(x) < 1.0

def constant(x):
    return 1.0 + 0 * x

if tab.selected('pytorch'):
    def epanechikov(x):
        return torch.max(1 - d2l.abs(x), torch.zeros_like(x))
if tab.selected('mxnet'):
    def epanechikov(x):
        return np.maximum(1 - d2l.abs(x), 0)
if tab.selected('tensorflow'):
    def epanechikov(x):
        return tf.maximum(1 - d2l.abs(x), 0)
if tab.selected('jax'):
    def epanechikov(x):
        return jnp.maximum(1 - d2l.abs(x), 0)
fig, axes = d2l.plt.subplots(1, 4, sharey=True, figsize=(12, 3))

kernels = (gaussian, boxcar, constant, epanechikov)
names = ('Gaussian', 'Boxcar', 'Constant', 'Epanechikov')
x = d2l.arange(-2.5, 2.5, 0.1)
for kernel, name, ax in zip(kernels, names, axes):
    if tab.selected('pytorch', 'mxnet', 'tensorflow'):
        ax.plot(d2l.numpy(x), d2l.numpy(kernel(x)))
    if tab.selected('jax'):
        ax.plot(x, kernel(x))
    ax.set_xlabel(name)

d2l.plt.show()

異なるカーネルは、範囲と滑らかさに関する異なる概念に対応する。たとえば、boxcar カーネルは距離 \(1\)(あるいは別に定義したハイパーパラメータ)以内の観測値にしか注目せず、しかもそれを無差別に行う。

Nadaraya–Watson 推定を実際に見てみるために、訓練データを定義しよう。以下では次の依存関係を用いる。

(11.2.3)\[y_i = 2\sin(x_i) + x_i + \epsilon,\]

ここで \(\epsilon\) は平均 0、分散 1 の正規分布から生成される。40 個の訓練例をサンプルする。

def f(x):
    return 2 * d2l.sin(x) + x

n = 40
if tab.selected('pytorch'):
    x_train, _ = torch.sort(d2l.rand(n) * 5)
    y_train = f(x_train) + d2l.randn(n)
if tab.selected('mxnet'):
    x_train = np.sort(d2l.rand(n) * 5, axis=None)
    y_train = f(x_train) + d2l.randn(n)
if tab.selected('tensorflow'):
    x_train = tf.sort(d2l.rand((n,1)) * 5, 0)
    y_train = f(x_train) + d2l.normal((n, 1))
if tab.selected('jax'):
    x_train = jnp.sort(jax.random.uniform(d2l.get_key(), (n,)) * 5)
    y_train = f(x_train) + jax.random.normal(d2l.get_key(), (n,))
x_val = d2l.arange(0, 5, 0.1)
y_val = f(x_val)

11.2.2. Nadaraya–Watson 回帰によるアテンションプーリング

データとカーネルがそろったので、あとはカーネル回帰の推定値を計算する関数だけである。なお、簡単な診断を行うために相対的なカーネル重みも得たいので、まず訓練特徴(共変量)x_train とすべての検証特徴 x_val の間のカーネルを計算する。これにより行列が得られ、それを正規化する。これを訓練ラベル y_train と掛け合わせると推定値が得られる。

(11.1.1) のアテンションプーリングを思い出してほしい。各検証特徴をクエリとし、各訓練特徴–ラベルの組をキー–値のペアとみなす。その結果、正規化された相対カーネル重み(以下の attention_w)が アテンション重み になる。

def nadaraya_watson(x_train, y_train, x_val, kernel):
    dists = d2l.reshape(x_train, (-1, 1)) - d2l.reshape(x_val, (1, -1))
    # Each column/row corresponds to each query/key
    k = d2l.astype(kernel(dists), d2l.float32)
    # Normalization over keys for each query
    attention_w = k / d2l.reduce_sum(k, 0)
    if tab.selected('pytorch'):
        y_hat = y_train@attention_w
    if tab.selected('mxnet'):
        y_hat = np.dot(y_train, attention_w)
    if tab.selected('tensorflow'):
        y_hat = d2l.transpose(d2l.transpose(y_train)@attention_w)
    if tab.selected('jax'):
        y_hat = y_train@attention_w
    return y_hat, attention_w

異なるカーネルがどのような推定を生み出すか見てみよう。

def plot(x_train, y_train, x_val, y_val, kernels, names, attention=False):
    fig, axes = d2l.plt.subplots(1, 4, sharey=True, figsize=(12, 3))
    for kernel, name, ax in zip(kernels, names, axes):
        y_hat, attention_w = nadaraya_watson(x_train, y_train, x_val, kernel)
        if attention:
            if tab.selected('pytorch', 'mxnet', 'tensorflow'):
                pcm = ax.imshow(d2l.numpy(attention_w), cmap='Reds')
            if tab.selected('jax'):
                pcm = ax.imshow(attention_w, cmap='Reds')
        else:
            ax.plot(x_val, y_hat)
            ax.plot(x_val, y_val, 'm--')
            ax.plot(x_train, y_train, 'o', alpha=0.5);
        ax.set_xlabel(name)
        if not attention:
            ax.legend(['y_hat', 'y'])
    if attention:
        fig.colorbar(pcm, ax=axes, shrink=0.7)
plot(x_train, y_train, x_val, y_val, kernels, names)
../_images/output_attention-pooling_dce4fd_25_0.svg

まず目につくのは、3 つの非自明なカーネル(Gaussian、Boxcar、Epanechikov)が、真の関数からそれほど離れていない、かなり実用的な推定を与えていることである。自明な推定 \(f(x) = \frac{1}{n} \sum_i y_i\) に帰着する constant カーネルだけが、かなり非現実的な結果を生む。アテンションの重み付けをもう少し詳しく見てみよう。

plot(x_train, y_train, x_val, y_val, kernels, names, attention=True)

この可視化から、Gaussian、Boxcar、Epanechikov の推定が非常によく似ている理由がはっきり分かる。結局のところ、カーネルの関数形は異なっていても、非常によく似たアテンション重みから導されているからである。ここで疑問になるのは、これが常に成り立つのかということである。

11.2.3. アテンションプーリングの適応

Gaussian カーネルを別の幅のものに置き換えることができる。つまり、 \(\alpha(\mathbf{q}, \mathbf{k}) = \exp\left(-\frac{1}{2 \sigma^2} \|\mathbf{q} - \mathbf{k}\|^2 \right)\) を使い、\(\sigma^2\) でカーネルの幅を決めることができる。これが結果に影響するか見てみよう。

sigmas = (0.1, 0.2, 0.5, 1)
names = ['Sigma ' + str(sigma) for sigma in sigmas]

def gaussian_with_width(sigma):
    return (lambda x: d2l.exp(-x**2 / (2*sigma**2)))

kernels = [gaussian_with_width(sigma) for sigma in sigmas]
plot(x_train, y_train, x_val, y_val, kernels, names)

明らかに、カーネルが狭いほど推定は滑らかでなくなる。同時に、局所的な変動にはよりよく適応する。対応するアテンション重みを見てみよう。

plot(x_train, y_train, x_val, y_val, kernels, names, attention=True)
../_images/output_attention-pooling_dce4fd_31_0.svg

予想どおり、カーネルが狭いほど、大きなアテンション重みを持つ範囲も狭くなる。また、同じ幅を選ぶことが必ずしも理想的ではないことも明らかである。実際、Silverman (1986) は局所密度に依存するヒューリスティックを提案した。このような「工夫」は他にも多数提案されている。たとえば、Norelli et al. (2022) は、クロスモーダルな画像・テキスト表現を設計するために、類似した最近傍補間技術を用いた。

この方法が半世紀以上前のものであるにもかかわらず、なぜここまで詳しく扱うのか、不思議に思う読者もいるかもしれない。第一に、これは現代のアテンション機構の最も初期の先駆けの一つだからである。第二に、可視化に非常に適している。第三に、そして同じくらい重要であるが、手作りのアテンション機構の限界を示してくれる。より良い戦略は、クエリとキーの表現を学習することで機構そのものを 学習する ことである。これが次の節で取り組む内容である。

11.2.4. まとめ

Nadaraya–Watson カーネル回帰は、現在のアテンション機構の初期の先駆けである。
分類にも回帰にも、ほとんどあるいはまったく学習や調整を必要とせず、そのまま使える。
アテンション重みは、クエリとキーの類似度(または距離)と、どれだけ類似した観測が利用可能かに応じて割り当てられる。

11.2.5. 演習

  1. Parzen windows による密度推定は \(\hat{p}(\mathbf{x}) = \frac{1}{n} \sum_i k(\mathbf{x}, \mathbf{x}_i)\) で与えられる。二値分類に対して、Parzen windows から得られる関数 \(\hat{p}(\mathbf{x}, y=1) - \hat{p}(\mathbf{x}, y=-1)\) が Nadaraya–Watson 分類と等価であることを証明せよ。

  2. Nadaraya–Watson 回帰において、カーネル幅の良い値を学習するための確率的勾配降下法を実装せよ。

    1. 上の推定値をそのまま使って \((f(\mathbf{x_i}) - y_i)^2\) を直接最小化するとどうなるか? ヒント: \(y_i\)\(f\) を計算するために使われる項の一部である。

    2. \(f(\mathbf{x}_i)\) の推定から \((\mathbf{x}_i, y_i)\) を除外し、カーネル幅について最適化せよ。それでも過学習は観測されるか?

  3. すべての \(\mathbf{x}\) が単位球面上にある、すなわちすべて \(\|\mathbf{x}\| = 1\) を満たすと仮定する。指数関数内の \(\|\mathbf{x} - \mathbf{x}_i\|^2\) の項を簡単化できるか? ヒント: これは後で、ドット積アテンションと非常に密接に関係していることが分かる。

  4. Mack and Silverman (1982) が Nadaraya–Watson 推定の整合性を証明したことを思い出してほしい。データが増えるにつれて、アテンション機構のスケールをどのくらいの速さで小さくすべきだろうか。 答えの直感も述べよ。これはデータの次元に依存するか? どのように依存するか?