9.4. 再帰型ニューラルネットワーク

9.3 章 では、言語モデリングのためのマルコフモデルと \(n\)-gram を説明した。そこでは、時刻 \(t\) におけるトークン \(x_t\) の条件付き確率は、直前の \(n-1\) 個のトークンのみに依存する。 時刻 \(t-(n-1)\) より前のトークンが \(x_t\) に及ぼしうる影響を取り込みたい場合は、 \(n\) を大きくする必要がある。 しかし、その場合、モデルパラメータの数もそれに伴って指数的に増加する。というのも、語彙集合 \(\mathcal{V}\) に対して \(|\mathcal{V}|^n\) 個の数を保持しなければならないからである。 したがって、\(P(x_t \mid x_{t-1}, \ldots, x_{1})\) を直接モデル化するよりも、潜在変数モデルを用いる方が望ましい。

(9.4.1)\[P(x_t \mid x_{t-1}, \ldots, x_1) \approx P(x_t \mid h_{t-1}),\]

ここで \(h_{t-1}\) は、時刻 \(t-1\) までの系列情報を保持する 隠れ状態 である。 一般に、 任意の時刻 \(t\) における隠れ状態は、現在の入力 \(x_{t}\) と前の隠れ状態 \(h_{t-1}\) の両方に基づいて計算できる。

(9.4.2)\[h_t = f(x_{t}, h_{t-1}).\]

(9.4.2) において十分に強力な関数 \(f\) を用いれば、この潜在変数モデルは近似ではない。結局のところ、\(h_t\) はこれまでに観測したすべてのデータを単に保持していてもよいからである。 しかし、その場合、計算と記憶の両方が高コストになる可能性がある。

5 章 で、隠れユニットを持つ隠れ層について説明したことを思い出してほしい。 ここで重要なのは、 隠れ層と隠れ状態は、まったく異なる概念だということである。 隠れ層は、説明したように、入力から出力へ至る経路の途中で外から見えない層である。 一方、隠れ状態は技術的には、ある時点で行う処理に対する 入力 であり、 過去の時刻のデータを見て初めて計算できる。

再帰型ニューラルネットワーク(RNN)は、隠れ状態を持つニューラルネットワークである。RNN モデルを導入する前に、まず 5.1 章 で導入した MLP モデルを振り返ろう。

from d2l import torch as d2l
import torch
from d2l import mxnet as d2l
from mxnet import np, npx
npx.set_np()
from d2l import jax as d2l
import jax
from jax import numpy as jnp
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
from d2l import tensorflow as d2l
import tensorflow as tf

9.4.1. 隠れ状態を持たないニューラルネットワーク

単一の隠れ層を持つ MLP を見てみよう。 隠れ層の活性化関数を \(\phi\) とする。 バッチサイズが \(n\)、入力次元が \(d\) のミニバッチの例 \(\mathbf{X} \in \mathbb{R}^{n \times d}\) が与えられたとき、隠れ層の出力 \(\mathbf{H} \in \mathbb{R}^{n \times h}\) は次のように計算される。

(9.4.3)\[\mathbf{H} = \phi(\mathbf{X} \mathbf{W}_{\textrm{xh}} + \mathbf{b}_\textrm{h}).\]

(9.4.3) では、隠れ層に対して重みパラメータ \(\mathbf{W}_{\textrm{xh}} \in \mathbb{R}^{d \times h}\)、バイアスパラメータ \(\mathbf{b}_\textrm{h} \in \mathbb{R}^{1 \times h}\)、および隠れユニット数 \(h\) を用いている。 このため、加算の際にはブロードキャスト(2.1.4 章 を参照)を適用する。 次に、隠れ層の出力 \(\mathbf{H}\) を出力層の入力として用いる。出力層は次式で与えられる。

(9.4.4)\[\mathbf{O} = \mathbf{H} \mathbf{W}_{\textrm{hq}} + \mathbf{b}_\textrm{q},\]

ここで \(\mathbf{O} \in \mathbb{R}^{n \times q}\) は出力変数、\(\mathbf{W}_{\textrm{hq}} \in \mathbb{R}^{h \times q}\) は重みパラメータ、\(\mathbf{b}_\textrm{q} \in \mathbb{R}^{1 \times q}\) は出力層のバイアスパラメータである。分類問題であれば、\(\mathrm{softmax}(\mathbf{O})\) を用いて出力カテゴリの確率分布を計算できる。

9.1 章 で以前に解いた回帰問題と完全に同様なので、詳細は省略する。 要するに、特徴とラベルのペアをランダムに取り出し、自動微分と確率的勾配降下法によってネットワークのパラメータを学習できるということである。

9.4.2. 隠れ状態を持つ再帰型ニューラルネットワーク

隠れ状態がある場合は、事情がまったく異なる。構造をもう少し詳しく見てみよう。

時刻 \(t\) において、ミニバッチの入力 \(\mathbf{X}_t \in \mathbb{R}^{n \times d}\) があると仮定する。 言い換えると、 \(n\) 個の系列例からなるミニバッチについて、 \(\mathbf{X}_t\) の各行は系列中の時刻 \(t\) における 1 つの例に対応する。 次に、 時刻 \(t\) の隠れ層出力を \(\mathbf{H}_t \in \mathbb{R}^{n \times h}\) と表す。 MLP とは異なり、ここでは前の時刻の隠れ層出力 \(\mathbf{H}_{t-1}\) を保持し、前の時刻の隠れ層出力を現在の時刻でどのように使うかを表す新しい重みパラメータ \(\mathbf{W}_{\textrm{hh}} \in \mathbb{R}^{h \times h}\) を導入する。具体的には、現在の時刻の隠れ層出力の計算は、現在の時刻の入力と前の時刻の隠れ層出力の両方によって決まる。

(9.4.5)\[\mathbf{H}_t = \phi(\mathbf{X}_t \mathbf{W}_{\textrm{xh}} + \mathbf{H}_{t-1} \mathbf{W}_{\textrm{hh}} + \mathbf{b}_\textrm{h}).\]

(9.4.3) と比べると、(9.4.5) では項 \(\mathbf{H}_{t-1} \mathbf{W}_{\textrm{hh}}\) が 1 つ追加されており、そのため (9.4.2) を具体化している。 隣接する時刻の隠れ層出力 \(\mathbf{H}_t\)\(\mathbf{H}_{t-1}\) の関係から、 これらの変数が、現在の時刻までの系列の履歴情報を保持していることがわかる。、ニューラルネットワークの現在の時刻における状態、あるいは記憶のようなものである。したがって、このような隠れ層出力は 隠れ状態 と呼ばれる。 隠れ状態は、前の時刻の定義を現在の時刻でそのまま用いるため、(9.4.5) の計算は 再帰的 である。したがって、先ほど述べたように、再帰的計算に基づく隠れ状態を持つニューラルネットワークは 再帰型ニューラルネットワーク と呼ばれる。 RNN において (9.4.5) の計算を行う層は、再帰層 と呼ばれる。

RNN の構成方法にはさまざまなものがある。 (9.4.5) で定義される隠れ状態を持つものは非常に一般的である。 時刻 \(t\) に対して、 出力層の出力は MLP の計算と同様である。

(9.4.6)\[\mathbf{O}_t = \mathbf{H}_t \mathbf{W}_{\textrm{hq}} + \mathbf{b}_\textrm{q}.\]

RNN のパラメータには、隠れ層の重み \(\mathbf{W}_{\textrm{xh}} \in \mathbb{R}^{d \times h}, \mathbf{W}_{\textrm{hh}} \in \mathbb{R}^{h \times h}\)、 バイアス \(\mathbf{b}_\textrm{h} \in \mathbb{R}^{1 \times h}\)、 および出力層の重み \(\mathbf{W}_{\textrm{hq}} \in \mathbb{R}^{h \times q}\) とバイアス \(\mathbf{b}_\textrm{q} \in \mathbb{R}^{1 \times q}\) が含まれる。 特筆すべき点として、 異なる時刻であっても、 RNN は常にこれらのモデルパラメータを使用する。 したがって、RNN のパラメータ化コストは 時刻数が増えても増加しない。

図 9.4.1 は、3 つの隣接する時刻における RNN の計算ロジックを示している。 任意の時刻 \(t\) において、 隠れ状態の計算は次のように扱える。 (i) 現在の時刻 \(t\) の入力 \(\mathbf{X}_t\) と前の時刻 \(t-1\) の隠れ状態 \(\mathbf{H}_{t-1}\) を連結する。 (ii) その連結結果を活性化関数 \(\phi\) を持つ全結合層に入力する。 このような全結合層の出力が、現在の時刻 \(t\) の隠れ状態 \(\mathbf{H}_t\) である。 この場合、 モデルパラメータは (9.4.5) にある \(\mathbf{W}_{\textrm{xh}}\)\(\mathbf{W}_{\textrm{hh}}\) の連結、およびバイアス \(\mathbf{b}_\textrm{h}\) である。 現在の時刻 \(t\) の隠れ状態 \(\mathbf{H}_t\) は、次の時刻 \(t+1\) の隠れ状態 \(\mathbf{H}_{t+1}\) の計算に関与する。 さらに、 \(\mathbf{H}_t\) は全結合の出力層にも入力され、 現在の時刻 \(t\) の出力 \(\mathbf{O}_t\) を計算する。

../_images/rnn.svg

図 9.4.1 隠れ状態を持つ RNN。

先ほど、隠れ状態のための \(\mathbf{X}_t \mathbf{W}_{\textrm{xh}} + \mathbf{H}_{t-1} \mathbf{W}_{\textrm{hh}}\) の計算は、 \(\mathbf{X}_t\)\(\mathbf{H}_{t-1}\) の連結と、 \(\mathbf{W}_{\textrm{xh}}\)\(\mathbf{W}_{\textrm{hh}}\) の連結との 行列積に等しいと述べた。 数学的に証明できるが、 以下では簡単なコード片で示すだけにする。 まず、 形状がそれぞれ (3, 1), (1, 4), (3, 4), (4, 4) の行列 X, W_xh, H, W_hh を定義する。 XW_xh を掛け、HW_hh を掛け、その 2 つの積を加えると、 形状 (3, 4) の行列が得られる。

X, W_xh = d2l.randn(3, 1), d2l.randn(1, 4)
H, W_hh = d2l.randn(3, 4), d2l.randn(4, 4)
d2l.matmul(X, W_xh) + d2l.matmul(H, W_hh)
tensor([[ 1.2834,  3.5848,  0.8610, -2.7687],
        [-1.7809, -3.3326,  0.1534,  3.7534],
        [-2.6922, -2.2836,  1.8133,  0.4768]])
X, W_xh = d2l.randn(3, 1), d2l.randn(1, 4)
H, W_hh = d2l.randn(3, 4), d2l.randn(4, 4)
d2l.matmul(X, W_xh) + d2l.matmul(H, W_hh)
[07:05:29] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU
array([[-0.21952915,  4.256434  ,  4.5812645 , -5.344988  ],
       [ 3.447858  , -3.0177274 , -1.6777471 ,  7.535347  ],
       [ 2.2390068 ,  1.4199957 ,  4.744728  , -8.421293  ]])
X, W_xh = jax.random.normal(d2l.get_key(), (3, 1)), jax.random.normal(
                                                        d2l.get_key(), (1, 4))
H, W_hh = jax.random.normal(d2l.get_key(), (3, 4)), jax.random.normal(
                                                        d2l.get_key(), (4, 4))
d2l.matmul(X, W_xh) + d2l.matmul(H, W_hh)
Array([[ 2.4425626 , -0.9535016 ,  0.67076457,  2.189545  ],
       [ 0.61951065, -2.9903038 , -1.23979   , -0.6381068 ],
       [ 1.8468229 , -1.243316  ,  0.359734  , -0.30748147]],      dtype=float32)
X, W_xh = d2l.normal((3, 1)), d2l.normal((1, 4))
H, W_hh = d2l.normal((3, 4)), d2l.normal((4, 4))
d2l.matmul(X, W_xh) + d2l.matmul(H, W_hh)
<tf.Tensor: shape=(3, 4), dtype=float32, numpy=
array([[-1.0680711e-01, -7.6638007e-01,  9.1176510e-02, -1.7421143e+00],
       [-2.1555585e-01, -1.0150743e+00, -6.2941134e-02,  7.0973635e-01],
       [ 3.7314057e-01, -2.1135561e-02,  1.1782646e-03, -8.8488179e-01]],
      dtype=float32)>

次に、行方向(axis 1)に沿って行列 XH を連結し、 列方向(axis 0)に沿って行列 W_xhW_hh を連結する。 これら 2 つの連結結果は、 それぞれ形状 (3, 5) と (5, 4) の行列になる。 この 2 つの連結した行列を掛け合わせると、 上と同じ形状 (3, 4) の出力行列が得られる。

d2l.matmul(d2l.concat((X, H), 1), d2l.concat((W_xh, W_hh), 0))
tensor([[ 1.2834,  3.5848,  0.8610, -2.7687],
        [-1.7809, -3.3326,  0.1534,  3.7534],
        [-2.6922, -2.2836,  1.8133,  0.4768]])

9.4.3. RNN に基づく文字レベル言語モデル

9.3 章 での言語モデリングでは、 現在および過去のトークンに基づいて 次のトークンを予測することを目指した。 そのため、元の系列を 1 トークンずらしたものを目標(ラベル)として用いる。 Bengio et al. (2003) は、言語モデリングにニューラルネットワークを用いることを最初に提案した。 以下では、RNN を用いて言語モデルを構築する方法を示す。 ミニバッチサイズを 1 とし、テキスト系列を “machine” とする。 後続の節での学習を簡単にするため、 テキストを単語ではなく文字にトークン化し、 文字レベル言語モデル を考える。 図 9.4.2 は、文字レベル言語モデリングのために RNN を通じて現在および過去の文字に基づいて次の文字を予測する方法を示している。

../_images/rnn-train.svg

図 9.4.2 RNN に基づく文字レベル言語モデル。入力系列と目標系列はそれぞれ “machin” と “achine” である。

学習過程では、 各時刻の出力層からの出力に対して softmax 演算を行い、その後、交差エントロピー損失を用いてモデル出力と目標との誤差を計算する。 隠れ層における隠れ状態の再帰的計算のため、 図 9.4.2 の時刻 3 における出力 \(\mathbf{O}_3\) は、テキスト系列 “m”, “a”, “c” によって決まる。系列の次の文字は学習データ中では “h” なので、時刻 3 の損失は、特徴系列 “m”, “a”, “c” に基づいて生成された次の文字の確率分布と、この時刻の目標 “h” に依存する。

実際には、各トークンは \(d\) 次元ベクトルで表され、バッチサイズ \(n>1\) を用いる。したがって、時刻 \(t\) における入力 \(\mathbf X_t\)\(n\times d\) 行列となり、 9.4.2 章 で説明した内容と同じである。

以下の節では、文字レベル言語モデルのための RNN を実装する。

9.4.4. まとめ

隠れ状態に対して再帰的計算を用いるニューラルネットワークを、再帰型ニューラルネットワーク(RNN)と呼ぶ。 RNN の隠れ状態は、現在の時刻までの系列の履歴情報を捉えることができる。再帰的計算を用いることで、RNN のモデルパラメータ数は時刻数が増えても増加しない。応用として、RNN は文字レベル言語モデルの構築に利用できる。

9.4.5. 演習

  1. RNN を用いてテキスト系列中の次の文字を予測する場合、任意の出力に必要な次元はどれくらいか。

  2. なぜ RNN は、テキスト系列中のある時刻におけるトークンの条件付き確率を、それ以前のすべてのトークンに基づいて表現できるのか。

  3. 長い系列を逆伝播すると、勾配はどうなるか。

  4. この節で説明した言語モデルに関連する問題にはどのようなものがあるか。