9.7. 時間を通した逆伝播

9.5 章 の演習を終えていれば、 勾配クリッピングが、まれに発生する巨大な勾配が 学習を不安定にするのを防ぐうえで不可欠であることを 見てきたはずである。 爆発する勾配は、長い系列にわたって逆伝播することに 起因すると示唆した。 現代的なRNNアーキテクチャを多数紹介する前に、 逆伝播 が系列モデルで数学的にどのように 機能するのかを、もう少し詳しく見てみよう。 この議論によって、消失 勾配と 爆発 勾配という概念に いくらか精密さが加わることを期待している。 5.3 章 でMLPを導入したときに、 計算グラフを通した順伝播と逆伝播についての議論を 思い出していただければ、RNNにおける順伝播は 比較的わかりやすいはずである。 RNNに逆伝播を適用することを 時間を通した逆伝播 と呼ぶ (Werbos, 1990)。 この手続きでは、RNNの計算グラフを 1タイムステップずつ展開(またはアンロール)する必要がある。 アンロールされたRNNは本質的にフィードフォワード型の ニューラルネットワークであり、 同じパラメータがアンロールされたネットワーク全体で 繰り返し現れ、各タイムステップに登場するという 特別な性質を持っている。 その後は、通常のフィードフォワード型ニューラルネットワークと同様に、 連鎖律を適用して、アンロールされたネットワークを通して 勾配を逆伝播できる。 各パラメータに関する勾配は、 アンロールされたネットワーク内でそのパラメータが現れる すべての箇所にわたって和を取らなければならない。 このような重み共有の扱いは、 畳み込みニューラルネットワークの章で すでに見てきたはずである。

系列はかなり長くなりうるため、問題が生じる。 1000個を超えるトークンからなるテキスト系列を扱うことは 珍しくない。 これは、計算量の観点(メモリが多すぎる)と 最適化の観点(数値的不安定性)の両方で 問題を引き起こす。 最初のステップからの入力は、出力に到達するまでに 1000回以上の行列積を通過し、 勾配を計算するためにもさらに1000回の行列積が必要である。 ここでは、何がうまくいかなくなるのか、 そして実際にはどう対処するのかを分析する。

9.7.1. RNNにおける勾配の解析

まず、RNNがどのように動作するかについての 単純化したモデルから始める。 このモデルでは、隠れ状態の具体的な詳細や その更新方法は無視する。 ここでの数学記法では、 スカラー、ベクトル、行列を明示的に区別しない。 あくまで直感を養うことが目的である。 この単純化したモデルでは、 時刻 \(t\) における隠れ状態を \(h_t\)、 入力を \(x_t\)、出力を \(o_t\) と表す。 9.4.2 章 で議論したように、 入力と隠れ状態は、隠れ層の1つの重み変数で 掛け算される前に連結できる。 したがって、隠れ層と出力層の重みをそれぞれ \(w_\textrm{h}\)\(w_\textrm{o}\) で表す。 その結果、各タイムステップにおける隠れ状態と出力は

(9.7.1)\[\begin{split}\begin{aligned}h_t &= f(x_t, h_{t-1}, w_\textrm{h}),\\o_t &= g(h_t, w_\textrm{o}),\end{aligned}\end{split}\]

となる。ここで \(f\)\(g\) はそれぞれ 隠れ層と出力層の変換である。 したがって、相互に依存する \(\{\ldots, (x_{t-1}, h_{t-1}, o_{t-1}), (x_{t}, h_{t}, o_t), \ldots\}\) という値の連鎖が、再帰的な計算を通じて得られる。 順伝播はかなり単純である。 必要なのは、\((x_t, h_t, o_t)\) の組を 1タイムステップずつ順にループすることだけである。 その後、出力 \(o_t\) と望ましい目標 \(y_t\) の差は、 全 \(T\) タイムステップにわたる目的関数で評価され、

(9.7.2)\[L(x_1, \ldots, x_T, y_1, \ldots, y_T, w_\textrm{h}, w_\textrm{o}) = \frac{1}{T}\sum_{t=1}^T l(y_t, o_t).\]

逆伝播では、特に目的関数 \(L\) のパラメータ \(w_\textrm{h}\) に関する 勾配を計算するときに、少し厄介になる。 具体的には、連鎖律により、

(9.7.3)\[\begin{split}\begin{aligned}\frac{\partial L}{\partial w_\textrm{h}} & = \frac{1}{T}\sum_{t=1}^T \frac{\partial l(y_t, o_t)}{\partial w_\textrm{h}} \\& = \frac{1}{T}\sum_{t=1}^T \frac{\partial l(y_t, o_t)}{\partial o_t} \frac{\partial g(h_t, w_\textrm{o})}{\partial h_t} \frac{\partial h_t}{\partial w_\textrm{h}}.\end{aligned}\end{split}\]

(9.7.3) における積の 最初と2番目の因子は簡単に計算できる。 3番目の因子 \(\partial h_t/\partial w_\textrm{h}\) が難所であり、 パラメータ \(w_\textrm{h}\)\(h_t\) に与える影響を 再帰的に計算する必要がある。 (9.7.1) の再帰的計算によれば、 \(h_t\)\(h_{t-1}\)\(w_\textrm{h}\) の両方に依存し、 \(h_{t-1}\) の計算もまた \(w_\textrm{h}\) に依存する。 したがって、連鎖律を用いて \(h_t\)\(w_\textrm{h}\) に関する 全微分を評価すると、

(9.7.4)\[\frac{\partial h_t}{\partial w_\textrm{h}}= \frac{\partial f(x_{t},h_{t-1},w_\textrm{h})}{\partial w_\textrm{h}} +\frac{\partial f(x_{t},h_{t-1},w_\textrm{h})}{\partial h_{t-1}} \frac{\partial h_{t-1}}{\partial w_\textrm{h}}.\]

上の勾配を導くために、3つの系列 \(\{a_{t}\},\{b_{t}\},\{c_{t}\}\)\(a_{0}=0\) かつ \(t=1, 2,\ldots\) に対して \(a_{t}=b_{t}+c_{t}a_{t-1}\) を満たすと仮定する。 すると \(t\geq 1\) について、次が容易に示せる。

(9.7.5)\[a_{t}=b_{t}+\sum_{i=1}^{t-1}\left(\prod_{j=i+1}^{t}c_{j}\right)b_{i}.\]

\(a_t\)\(b_t\)\(c_t\) をそれぞれ

(9.7.6)\[\begin{split}\begin{aligned}a_t &= \frac{\partial h_t}{\partial w_\textrm{h}},\\ b_t &= \frac{\partial f(x_{t},h_{t-1},w_\textrm{h})}{\partial w_\textrm{h}}, \\ c_t &= \frac{\partial f(x_{t},h_{t-1},w_\textrm{h})}{\partial h_{t-1}},\end{aligned}\end{split}\]

に従って置き換えると、(9.7.4) の勾配計算は \(a_{t}=b_{t}+c_{t}a_{t-1}\) を満たす。 したがって、(9.7.5) により、 (9.7.4) における再帰計算を

(9.7.7)\[\frac{\partial h_t}{\partial w_\textrm{h}}=\frac{\partial f(x_{t},h_{t-1},w_\textrm{h})}{\partial w_\textrm{h}}+\sum_{i=1}^{t-1}\left(\prod_{j=i+1}^{t} \frac{\partial f(x_{j},h_{j-1},w_\textrm{h})}{\partial h_{j-1}} \right) \frac{\partial f(x_{i},h_{i-1},w_\textrm{h})}{\partial w_\textrm{h}}.\]

で取り除くことができる。

連鎖律を使って \(\partial h_t/\partial w_\textrm{h}\) を再帰的に計算できるが、 \(t\) が大きいとこの連鎖は非常に長くなる。 この問題に対処するいくつかの戦略を議論しよう。

9.7.1.1. 全計算

1つの考え方は、(9.7.7) の 完全な和を計算することかもしれない。 しかし、これは非常に遅く、勾配が爆発する可能性がある。 というのも、初期条件のわずかな変化が 結果に大きく影響しうるからである。 つまり、初期条件の小さな変化が 結果に不釣り合いな変化をもたらす、 いわゆるバタフライ効果に似た現象が 起こりえる。 これは一般に望ましくない。 結局のところ、私たちが求めているのは、 よく一般化する頑健な推定量である。 したがって、この戦略が実際に使われることは ほとんどない。

9.7.1.2. タイムステップの切り詰め

別の方法として、 (9.7.7) の和を \(\tau\) ステップ後で打ち切ることができる。 これが、ここまで議論してきた方法である。 これは、和を \(\partial h_{t-\tau}/\partial w_\textrm{h}\) で 単純に終了させることで、真の勾配の 近似 を与える。 実際には、これはかなりうまく機能する。 これは一般に、切り詰めた 時間を通した逆伝播と呼ばれる (Jaeger, 2002)。 この結果の1つとして、モデルは 長期的な結果よりも短期的な影響に 主として注目するようになる。 これは実際には 望ましい ことであり、 より単純で安定したモデルへと推定を 偏らせるからである。

9.7.1.3. ランダム化された切り詰め

最後に、\(\partial h_t/\partial w_\textrm{h}\) を、 期待値では正しいが系列を切り詰める 確率変数で置き換えることができる。 これは、あらかじめ定めた \(0 \leq \pi_t \leq 1\) を持つ \(\xi_t\) の系列を用いて実現され、 \(P(\xi_t = 0) = 1-\pi_t\) かつ \(P(\xi_t = \pi_t^{-1}) = \pi_t\) となるため、 \(E[\xi_t] = 1\) である。 これを用いて、(9.7.4) における 勾配 \(\partial h_t/\partial w_\textrm{h}\)

(9.7.8)\[z_t= \frac{\partial f(x_{t},h_{t-1},w_\textrm{h})}{\partial w_\textrm{h}} +\xi_t \frac{\partial f(x_{t},h_{t-1},w_\textrm{h})}{\partial h_{t-1}} \frac{\partial h_{t-1}}{\partial w_\textrm{h}}.\]

で置き換える。

\(\xi_t\) の定義から、\(E[z_t] = \partial h_t/\partial w_\textrm{h}\) が 成り立つ。 \(\xi_t = 0\) のときはいつでも、再帰計算は そのタイムステップ \(t\) で終了する。 これにより、長さの異なる系列の重み付き和が得られ、 長い系列はまれだが適切に重み付けされる。 この考え方は Tallec and Ollivier (2017) によって提案された。

9.7.1.4. 戦略の比較

../_images/truncated-bptt.svg

図 9.7.1 RNNにおける勾配計算の戦略の比較。上から順に、ランダム化された切り詰め、通常の切り詰め、全計算。

図 9.7.1 は、RNNに対して時間を通した逆伝播を用いて The Time Machine の最初の数文字を解析するときの 3つの戦略を示している。

  • 1行目は、長さの異なる区間にテキストを分割する ランダム化された切り詰めである。

  • 2行目は、同じ長さの部分系列にテキストを分割する 通常の切り詰めである。これは、RNNの実験で これまで行ってきた方法である。

  • 3行目は、計算上実行不可能な式につながる 全時間逆伝播である。

残念ながら、理論的には魅力的であるものの、 ランダム化された切り詰めは通常の切り詰めよりも 大きく優れているわけではない。 おそらく、いくつかの要因によるものである。 第1に、過去へいくつかの逆伝播ステップを経た後の 観測の影響は、実際には依存関係を捉えるのに 十分である。 第2に、分散の増加は、より多くのステップで 勾配がより正確になるという事実を打ち消す。 第3に、私たちは実際には、相互作用の範囲が 短いモデルを 望んで いる。 したがって、通常の切り詰めを用いた時間を通した逆伝播には、 望ましい正則化効果がわずかにある。

9.7.2. 時間を通した逆伝播の詳細

一般原理を議論したので、 次に時間を通した逆伝播を詳しく見ていこう。 9.7.1 章 の解析とは対照的に、 以下では、分解されたすべてのモデルパラメータに関する 目的関数の勾配をどのように計算するかを示す。 簡単のため、バイアスパラメータを持たず、 隠れ層の活性化関数として恒等写像 (\(\phi(x)=x\))を用いるRNNを考える。 タイムステップ \(t\) において、単一の例の入力と目標を それぞれ \(\mathbf{x}_t \in \mathbb{R}^d\)\(y_t\) とする。 隠れ状態 \(\mathbf{h}_t \in \mathbb{R}^h\) と 出力 \(\mathbf{o}_t \in \mathbb{R}^q\) は次のように計算される。

(9.7.9)\[\begin{split}\begin{aligned}\mathbf{h}_t &= \mathbf{W}_\textrm{hx} \mathbf{x}_t + \mathbf{W}_\textrm{hh} \mathbf{h}_{t-1},\\ \mathbf{o}_t &= \mathbf{W}_\textrm{qh} \mathbf{h}_{t},\end{aligned}\end{split}\]

ここで \(\mathbf{W}_\textrm{hx} \in \mathbb{R}^{h \times d}\)\(\mathbf{W}_\textrm{hh} \in \mathbb{R}^{h \times h}\)\(\mathbf{W}_\textrm{qh} \in \mathbb{R}^{q \times h}\) は 重みパラメータである。 時刻 \(t\) における損失を \(l(\mathbf{o}_t, y_t)\) と表す。 したがって、私たちの目的関数、すなわち系列の先頭から \(T\) タイムステップにわたる損失は

(9.7.10)\[L = \frac{1}{T} \sum_{t=1}^T l(\mathbf{o}_t, y_t).\]

RNNの計算中にモデル変数とパラメータの間の依存関係を 可視化するために、 図 9.7.2 に示すような モデルの計算グラフを描くことができる。 たとえば、3番目のタイムステップの隠れ状態 \(\mathbf{h}_3\) の計算は、モデルパラメータ \(\mathbf{W}_\textrm{hx}\)\(\mathbf{W}_\textrm{hh}\)、 前のタイムステップの隠れ状態 \(\mathbf{h}_2\)、 および現在のタイムステップの入力 \(\mathbf{x}_3\) に依存する。

../_images/rnn-bptt.svg

図 9.7.2 3タイムステップのRNNモデルにおける依存関係を示す計算グラフ。箱は変数(塗りつぶしなし)またはパラメータ(塗りつぶしあり)を表し、円は演算子を表す。

先ほど述べたように、 図 9.7.2 のモデルパラメータは \(\mathbf{W}_\textrm{hx}\)\(\mathbf{W}_\textrm{hh}\)\(\mathbf{W}_\textrm{qh}\) である。 一般に、このモデルの学習には、これらのパラメータに関する 勾配計算 \(\partial L/\partial \mathbf{W}_\textrm{hx}\)\(\partial L/\partial \mathbf{W}_\textrm{hh}\)\(\partial L/\partial \mathbf{W}_\textrm{qh}\) が必要である。 図 9.7.2 の依存関係に従って、 矢印と逆向きにたどることで、順に勾配を計算して保存できる。 連鎖律において、異なる形状を持つ 行列、ベクトル、スカラーの積を柔軟に表現するために、 5.3 章 で説明したように \(\textrm{prod}\) 演算子を引き続き用いる。

まず、任意のタイムステップ \(t\) における モデル出力に関して目的関数を微分するのは かなり単純である。

(9.7.11)\[\frac{\partial L}{\partial \mathbf{o}_t} = \frac{\partial l (\mathbf{o}_t, y_t)}{T \cdot \partial \mathbf{o}_t} \in \mathbb{R}^q.\]

これで、出力層のパラメータ \(\mathbf{W}_\textrm{qh}\) に関する 目的関数の勾配 \(\partial L/\partial \mathbf{W}_\textrm{qh} \in \mathbb{R}^{q \times h}\) を計算できる。 図 9.7.2 に基づけば、 目的関数 \(L\)\(\mathbf{o}_1, \ldots, \mathbf{o}_T\) を通じて \(\mathbf{W}_\textrm{qh}\) に依存する。 連鎖律を用いると、

(9.7.12)\[\frac{\partial L}{\partial \mathbf{W}_\textrm{qh}} = \sum_{t=1}^T \textrm{prod}\left(\frac{\partial L}{\partial \mathbf{o}_t}, \frac{\partial \mathbf{o}_t}{\partial \mathbf{W}_\textrm{qh}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \mathbf{o}_t} \mathbf{h}_t^\top,\]

ここで \(\partial L/\partial \mathbf{o}_t\)(9.7.11) で与えられる。

次に、 図 9.7.2 に示すように、 最終タイムステップ \(T\) では、目的関数 \(L\) は 出力 \(\mathbf{o}_T\) を通じてのみ隠れ状態 \(\mathbf{h}_T\) に依存する。 したがって、連鎖律を用いて 隠れ状態に関する勾配 \(\partial L/\partial \mathbf{h}_T \in \mathbb{R}^h\) を容易に求められる。

(9.7.13)\[\frac{\partial L}{\partial \mathbf{h}_T} = \textrm{prod}\left(\frac{\partial L}{\partial \mathbf{o}_T}, \frac{\partial \mathbf{o}_T}{\partial \mathbf{h}_T} \right) = \mathbf{W}_\textrm{qh}^\top \frac{\partial L}{\partial \mathbf{o}_T}.\]

任意のタイムステップ \(t < T\) では事情が少し複雑になる。 このとき、目的関数 \(L\)\(\mathbf{h}_{t+1}\)\(\mathbf{o}_t\) を通じて \(\mathbf{h}_t\) に依存する。 連鎖律に従うと、任意のタイムステップ \(t < T\) における 隠れ状態の勾配 \(\partial L/\partial \mathbf{h}_t \in \mathbb{R}^h\) は 再帰的に次のように計算できる。

(9.7.14)\[\frac{\partial L}{\partial \mathbf{h}_t} = \textrm{prod}\left(\frac{\partial L}{\partial \mathbf{h}_{t+1}}, \frac{\partial \mathbf{h}_{t+1}}{\partial \mathbf{h}_t} \right) + \textrm{prod}\left(\frac{\partial L}{\partial \mathbf{o}_t}, \frac{\partial \mathbf{o}_t}{\partial \mathbf{h}_t} \right) = \mathbf{W}_\textrm{hh}^\top \frac{\partial L}{\partial \mathbf{h}_{t+1}} + \mathbf{W}_\textrm{qh}^\top \frac{\partial L}{\partial \mathbf{o}_t}.\]

解析のために、任意のタイムステップ \(1 \leq t \leq T\) について 再帰計算を展開すると、

(9.7.15)\[\frac{\partial L}{\partial \mathbf{h}_t}= \sum_{i=t}^T {\left(\mathbf{W}_\textrm{hh}^\top\right)}^{T-i} \mathbf{W}_\textrm{qh}^\top \frac{\partial L}{\partial \mathbf{o}_{T+t-i}}.\]

(9.7.15) からわかるように、 この単純な線形例でさえ、長い系列モデルの重要な問題のいくつかを すでに示している。 そこには、\(\mathbf{W}_\textrm{hh}^\top\) の非常に大きな冪が 含まれうるのである。 その固有値が1より小さいと消失し、 1より大きいと発散する。 これは数値的に不安定であり、 消失勾配や爆発勾配として現れる。 これに対処する1つの方法は、 9.7.1 章 で議論したように、 タイムステップを計算上扱いやすい大きさで 切り詰めることである。 実際には、この切り詰めは、 一定数のタイムステップの後で勾配を切り離すことでも 実現できる。 後ほど、長短期記憶(LSTM)のような より洗練された系列モデルが、これをさらに緩和できることを見る。

最後に、 図 9.7.2 は、 目的関数 \(L\) が、隠れ状態 \(\mathbf{h}_1, \ldots, \mathbf{h}_T\) を通じて、 隠れ層のモデルパラメータ \(\mathbf{W}_\textrm{hx}\)\(\mathbf{W}_\textrm{hh}\) に依存することを示している。 このようなパラメータに関する勾配 \(\partial L / \partial \mathbf{W}_\textrm{hx} \in \mathbb{R}^{h \times d}\)\(\partial L / \partial \mathbf{W}_\textrm{hh} \in \mathbb{R}^{h \times h}\) を計算するには、 連鎖律を適用して

(9.7.16)\[\begin{split}\begin{aligned} \frac{\partial L}{\partial \mathbf{W}_\textrm{hx}} &= \sum_{t=1}^T \textrm{prod}\left(\frac{\partial L}{\partial \mathbf{h}_t}, \frac{\partial \mathbf{h}_t}{\partial \mathbf{W}_\textrm{hx}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \mathbf{h}_t} \mathbf{x}_t^\top,\\ \frac{\partial L}{\partial \mathbf{W}_\textrm{hh}} &= \sum_{t=1}^T \textrm{prod}\left(\frac{\partial L}{\partial \mathbf{h}_t}, \frac{\partial \mathbf{h}_t}{\partial \mathbf{W}_\textrm{hh}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \mathbf{h}_t} \mathbf{h}_{t-1}^\top, \end{aligned}\end{split}\]

ここで、(9.7.13)(9.7.14) によって再帰的に計算される \(\partial L/\partial \mathbf{h}_t\) が、 数値安定性に影響する重要な量である。

時間を通した逆伝播は、RNNにおける逆伝播の適用そのものであるため、 5.3 章 で説明したように、 RNNの学習では順伝播と時間を通した逆伝播を交互に行う。 さらに、時間を通した逆伝播では、 上記の勾配を順に計算して保存する。 具体的には、保存された中間値を再利用して 重複計算を避ける。たとえば、 \(\partial L/\partial \mathbf{h}_t\) を保存しておき、 \(\partial L / \partial \mathbf{W}_\textrm{hx}\)\(\partial L / \partial \mathbf{W}_\textrm{hh}\) の両方の計算に 使う。

9.7.3. まとめ

時間を通した逆伝播は、隠れ状態を持つ系列モデルに 逆伝播を適用したものにすぎない。 通常の切り詰めやランダム化された切り詰めのような 切り詰めは、計算上の都合と数値安定性のために必要である。 行列の高次冪は、発散または消失する固有値を引き起こしえる。 これは、爆発勾配や消失勾配として現れる。 効率的に計算するために、時間を通した逆伝播の間、 中間値はキャッシュされる。

9.7.4. 演習

  1. 対称行列 \(\mathbf{M} \in \mathbb{R}^{n \times n}\) があり、その固有値を \(\lambda_i\)、対応する固有ベクトルを \(\mathbf{v}_i\)\(i = 1, \ldots, n\))とする。一般性を失わずに、\(|\lambda_i| \geq |\lambda_{i+1}|\) の順に並んでいると仮定する。

    1. \(\mathbf{M}^k\) の固有値が \(\lambda_i^k\) であることを示せ。

    2. ランダムなベクトル \(\mathbf{x} \in \mathbb{R}^n\) に対して、高い確率で \(\mathbf{M}^k \mathbf{x}\)\(\mathbf{M}\) の固有ベクトル \(\mathbf{v}_1\) に非常によく整列することを証明せよ。この主張を形式化せよ。

    3. 上の結果はRNNの勾配にとって何を意味するか。

  2. 勾配クリッピング以外に、再帰型ニューラルネットワークにおける勾配爆発に対処する他の方法を思いつくだろうか?