.. _sec_optimization-intro: 最適化と深層学習 ================ この節では、最適化と深層学習の関係、および深層学習における最適化の利用に伴う課題について議論する。 深層学習の問題では、通常まず *損失関数* を定義する。損失関数が得られれば、その損失を最小化しようとして最適化アルゴリズムを用いることができる。 最適化では、損失関数はしばしば最適化問題の *目的関数* と呼ばれる。慣例として、ほとんどの最適化アルゴリズムは *最小化* を扱う。もし目的関数を最大化したい場合は、簡単な解決策がある。目的関数の符号を反転させればよい。 最適化の目標 ------------ 最適化は深層学習における損失関数を最小化する手段を与えてくれるが、本質的には、最適化と深層学習の目標は根本的に異なる。 前者は主として目的関数を最小化することに関心があるのに対し、後者は有限個のデータが与えられたときに適切なモデルを見つけることに関心がある。 :numref:`sec_generalization_basics` では、この2つの目標の違いを詳しく議論した。 たとえば、訓練誤差と汎化誤差は一般に異なる。最適化アルゴリズムの目的関数は通常、訓練データセットに基づく損失関数であるため、最適化の目標は訓練誤差を減らすことである。 しかし、深層学習(より広くは統計的推論)の目標は汎化誤差を減らすことである。 後者を達成するには、最適化アルゴリズムを用いて訓練誤差を減らすことに加えて、過学習にも注意を払う必要がある。 .. raw:: html
pytorchmxnetjaxtensorflow
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python %matplotlib inline from d2l import torch as d2l import numpy as np from mpl_toolkits import mplot3d import torch .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python %matplotlib inline from d2l import mxnet as d2l from mpl_toolkits import mplot3d from mxnet import np, npx npx.set_np() .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python def f(x): return x * d2l.cos(np.pi * x) def g(x): return f(x) + 0.2 * d2l.cos(5 * np.pi * x) .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python %matplotlib inline from d2l import tensorflow as d2l import numpy as np from mpl_toolkits import mplot3d import tensorflow as tf .. raw:: html
.. raw:: html
上述の異なる目標を説明するために、 経験リスクとリスクを考えてみよう。 :numref:`subsec_empirical-risk-and-risk` で述べたように、経験リスクは訓練データセット上の平均損失であり、一方リスクはデータ全体の母集団に対する期待損失である。 以下では2つの関数を定義する。 リスク関数 ``f`` と経験リスク関数 ``g`` である。 訓練データが有限個しかないと仮定しよう。 その結果、ここでの ``g`` は ``f`` よりも滑らかではない。 .. raw:: html
pytorchmxnetjaxtensorflow
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python def f(x): return x * d2l.cos(np.pi * x) def g(x): return f(x) + 0.2 * d2l.cos(5 * np.pi * x) .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python def f(x): return x * d2l.cos(np.pi * x) def g(x): return f(x) + 0.2 * d2l.cos(5 * np.pi * x) .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python def annotate(text, xy, xytext): #@save d2l.plt.gca().annotate(text, xy=xy, xytext=xytext, arrowprops=dict(arrowstyle='->')) x = d2l.arange(0.5, 1.5, 0.01) d2l.set_figsize((4.5, 2.5)) d2l.plot(x, [f(x), g(x)], 'x', 'risk') annotate('min of\nempirical risk', (1.0, -1.2), (0.5, -1.1)) annotate('min of risk', (1.1, -1.05), (0.95, -0.5)) .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python def f(x): return x * d2l.cos(np.pi * x) def g(x): return f(x) + 0.2 * d2l.cos(5 * np.pi * x) .. raw:: html
.. raw:: html
以下のグラフは、訓練データセット上の経験リスクの最小値が、リスク(汎化誤差)の最小値とは異なる位置にある場合があることを示している。 .. raw:: html
pytorchmxnetjaxtensorflow
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python def annotate(text, xy, xytext): #@save d2l.plt.gca().annotate(text, xy=xy, xytext=xytext, arrowprops=dict(arrowstyle='->')) x = d2l.arange(0.5, 1.5, 0.01) d2l.set_figsize((4.5, 2.5)) d2l.plot(x, [f(x), g(x)], 'x', 'risk') annotate('min of\nempirical risk', (1.0, -1.2), (0.5, -1.1)) annotate('min of risk', (1.1, -1.05), (0.95, -0.5)) .. figure:: output_optimization-intro_111d97_33_0.svg .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python def annotate(text, xy, xytext): #@save d2l.plt.gca().annotate(text, xy=xy, xytext=xytext, arrowprops=dict(arrowstyle='->')) x = d2l.arange(0.5, 1.5, 0.01) d2l.set_figsize((4.5, 2.5)) d2l.plot(x, [f(x), g(x)], 'x', 'risk') annotate('min of\nempirical risk', (1.0, -1.2), (0.5, -1.1)) annotate('min of risk', (1.1, -1.05), (0.95, -0.5)) .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output [07:06:10] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU .. figure:: output_optimization-intro_111d97_36_1.svg .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python x = d2l.arange(-1.0, 2.0, 0.01) d2l.plot(x, [f(x), ], 'x', 'f(x)') annotate('local minimum', (-0.3, -0.25), (-0.77, -1.0)) annotate('global minimum', (1.1, -0.95), (0.6, 0.8)) .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python def annotate(text, xy, xytext): #@save d2l.plt.gca().annotate(text, xy=xy, xytext=xytext, arrowprops=dict(arrowstyle='->')) x = d2l.arange(0.5, 1.5, 0.01) d2l.set_figsize((4.5, 2.5)) d2l.plot(x, [f(x), g(x)], 'x', 'risk') annotate('min of\nempirical risk', (1.0, -1.2), (0.5, -1.1)) annotate('min of risk', (1.1, -1.05), (0.95, -0.5)) .. figure:: output_optimization-intro_111d97_42_0.svg .. raw:: html
.. raw:: html
深層学習における最適化の課題 ---------------------------- この章では、モデルの汎化誤差ではなく、目的関数を最小化する際の最適化アルゴリズムの性能に特に焦点を当てる。 :numref:`sec_linear_regression` では、最適化問題における解析解と数値解を区別した。 深層学習では、ほとんどの目的関数は複雑で、解析解を持たない。その代わりに、数値最適化アルゴリズムを使わなければならない。 この章で扱う最適化アルゴリズムはすべて、この カテゴリに属する。 深層学習の最適化には多くの課題がある。特に厄介なのは、局所最小値、鞍点、そして勾配消失である。 それらを見ていこう。 局所最小値 ~~~~~~~~~~ 任意の目的関数 :math:`f(x)` について、 ある点 :math:`x` における :math:`f(x)` の値が、その近傍の他の任意の点における :math:`f(x)` の値よりも小さいなら、\ :math:`f(x)` は局所最小値である可能性がある。 ある点 :math:`x` における :math:`f(x)` の値が、定義域全体にわたる目的関数の最小値であるなら、\ :math:`f(x)` は大域最小値である。 たとえば、次の関数が与えられたとする。 .. math:: f(x) = x \cdot \textrm{cos}(\pi x) \textrm{ for } -1.0 \leq x \leq 2.0, この関数の局所最小値と大域最小値を近似できる。 .. raw:: html
pytorchmxnetjaxtensorflow
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python x = d2l.arange(-1.0, 2.0, 0.01) d2l.plot(x, [f(x), ], 'x', 'f(x)') annotate('local minimum', (-0.3, -0.25), (-0.77, -1.0)) annotate('global minimum', (1.1, -0.95), (0.6, 0.8)) .. figure:: output_optimization-intro_111d97_48_0.svg .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python x = d2l.arange(-1.0, 2.0, 0.01) d2l.plot(x, [f(x), ], 'x', 'f(x)') annotate('local minimum', (-0.3, -0.25), (-0.77, -1.0)) annotate('global minimum', (1.1, -0.95), (0.6, 0.8)) .. figure:: output_optimization-intro_111d97_51_0.svg .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python x = d2l.arange(-2.0, 2.0, 0.01) d2l.plot(x, [x**3], 'x', 'f(x)') annotate('saddle point', (0, -0.2), (-0.52, -5.0)) .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python x = d2l.arange(-1.0, 2.0, 0.01) d2l.plot(x, [f(x), ], 'x', 'f(x)') annotate('local minimum', (-0.3, -0.25), (-0.77, -1.0)) annotate('global minimum', (1.1, -0.95), (0.6, 0.8)) .. figure:: output_optimization-intro_111d97_57_0.svg .. raw:: html
.. raw:: html
深層学習モデルの目的関数は通常、多くの局所最適解を持つ。 最適化問題の数値解が局所最適解の近くにあるとき、最終反復で得られる数値解は、目的関数の勾配が0に近づくか0になるため、目的関数を *大域的に* ではなく *局所的に* しか最小化しないかもしれない。 ある程度のノイズがなければ、パラメータを局所最小値から押し出せないことがある。実際、これはミニバッチ確率的勾配降下法の有益な性質の1つであり、ミニバッチ間で勾配が自然に変動することで、パラメータを局所最小値から引き離すことができる。 鞍点 ~~~~ 局所最小値に加えて、鞍点も勾配が消失する別の原因である。\ *鞍点* とは、関数のすべての勾配が0になるものの、大域最小値でも局所最小値でもない位置のことである。 関数 :math:`f(x) = x^3` を考えよう。その1階導関数と2階導関数は :math:`x=0` で0になる。最適化は、この点で停滞する可能性があるが、ここは最小値ではない。 .. raw:: html
pytorchmxnetjaxtensorflow
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python x = d2l.arange(-2.0, 2.0, 0.01) d2l.plot(x, [x**3], 'x', 'f(x)') annotate('saddle point', (0, -0.2), (-0.52, -5.0)) .. figure:: output_optimization-intro_111d97_63_0.svg .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python x = d2l.arange(-2.0, 2.0, 0.01) d2l.plot(x, [x**3], 'x', 'f(x)') annotate('saddle point', (0, -0.2), (-0.52, -5.0)) .. figure:: output_optimization-intro_111d97_66_0.svg .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python x = d2l.arange(-2.0, 5.0, 0.01) d2l.plot(x, [d2l.tanh(x)], 'x', 'f(x)') annotate('vanishing gradient', (4, 1), (2, 0.0)) .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python x = d2l.arange(-2.0, 2.0, 0.01) d2l.plot(x, [x**3], 'x', 'f(x)') annotate('saddle point', (0, -0.2), (-0.52, -5.0)) .. figure:: output_optimization-intro_111d97_72_0.svg .. raw:: html
.. raw:: html
高次元では鞍点はさらに厄介である。以下の例を見てみよう。関数 :math:`f(x, y) = x^2 - y^2` を考える。この関数の鞍点は :math:`(0, 0)` にある。これは :math:`y` に関しては最大値であり、\ :math:`x` に関しては最小値である。さらに、実際に *鞍* のように見える。これが、この数学的性質の名前の由来である。 .. raw:: html
pytorchmxnettensorflow
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python x, y = d2l.meshgrid( d2l.linspace(-1.0, 1.0, 101), d2l.linspace(-1.0, 1.0, 101)) z = x**2 - y**2 ax = d2l.plt.figure().add_subplot(111, projection='3d') ax.plot_wireframe(x, y, z, **{'rstride': 10, 'cstride': 10}) ax.plot([0], [0], [0], 'rx') ticks = [-1, 0, 1] d2l.plt.xticks(ticks) d2l.plt.yticks(ticks) ax.set_zticks(ticks) d2l.plt.xlabel('x') d2l.plt.ylabel('y'); .. figure:: output_optimization-intro_111d97_78_0.svg .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python x, y = d2l.meshgrid( d2l.linspace(-1.0, 1.0, 101), d2l.linspace(-1.0, 1.0, 101)) z = x**2 - y**2 ax = d2l.plt.figure().add_subplot(111, projection='3d') ax.plot_wireframe(x.asnumpy(), y.asnumpy(), z.asnumpy(), **{'rstride': 10, 'cstride': 10}) ax.plot([0], [0], [0], 'rx') ticks = [-1, 0, 1] d2l.plt.xticks(ticks) d2l.plt.yticks(ticks) ax.set_zticks(ticks) d2l.plt.xlabel('x') d2l.plt.ylabel('y'); .. figure:: output_optimization-intro_111d97_81_0.svg .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python x, y = d2l.meshgrid( d2l.linspace(-1.0, 1.0, 101), d2l.linspace(-1.0, 1.0, 101)) z = x**2 - y**2 ax = d2l.plt.figure().add_subplot(111, projection='3d') ax.plot_wireframe(x, y, z, **{'rstride': 10, 'cstride': 10}) ax.plot([0], [0], [0], 'rx') ticks = [-1, 0, 1] d2l.plt.xticks(ticks) d2l.plt.yticks(ticks) ax.set_zticks(ticks) d2l.plt.xlabel('x') d2l.plt.ylabel('y'); .. figure:: output_optimization-intro_111d97_84_0.svg .. raw:: html
.. raw:: html
関数の入力が :math:`k` 次元ベクトルで出力がスカラーであると仮定すると、そのヘッセ行列は :math:`k` 個の固有値を持つ。 関数の解は、関数の勾配が0である位置において、局所最小値、局所最大値、または鞍点になりえる。 - 関数のヘッセ行列の固有値が、勾配が0の位置で全て正なら、その関数は局所最小値を持つ。 - 関数のヘッセ行列の固有値が、勾配が0の位置で全て負なら、その関数は局所最大値を持つ。 - 関数のヘッセ行列の固有値が、勾配が0の位置で負と正の両方を含むなら、その関数は鞍点を持つ。 高次元問題では、少なくとも *いくつか* の固有値が負である確率はかなり高くなる。そのため、鞍点は局所最小値よりも起こりやすくなる。次の節で凸性を導入するときに、この状況に対するいくつかの例外を議論する。要するに、凸関数とはヘッセ行列の固有値が決して負にならない関数である。残念ながら、深層学習の問題の大半はこのカテゴリに入らない。それでも、最適化アルゴリズムを学ぶうえで非常に有用な道具である。 勾配消失 ~~~~~~~~ おそらく最も厄介な問題は勾配消失である。 :numref:`subsec_activation-functions` で扱った、よく使われる活性化関数とその導関数を思い出されたい。 たとえば、関数 :math:`f(x) = \tanh(x)` を最小化したいとして、たまたま :math:`x = 4` から始めたとする。見てのとおり、\ :math:`f` の勾配はほとんど0である。 より具体的には、\ :math:`f'(x) = 1 - \tanh^2(x)` なので、\ :math:`f'(4) = 0.0013` である。 その結果、進展が見られるまで最適化は長い間停滞する。これは、ReLU活性化関数が導入される以前に深層学習モデルの訓練がかなり難しかった理由の1つである。 .. raw:: latex \diilbookstyleinputcell .. code:: python x = d2l.arange(-2.0, 5.0, 0.01) d2l.plot(x, [d2l.tanh(x)], 'x', 'f(x)') annotate('vanishing gradient', (4, 1), (2, 0.0)) .. figure:: output_optimization-intro_111d97_88_0.svg 見てきたように、深層学習の最適化には課題が山積している。幸いなことに、性能が良く、初心者でも使いやすい堅牢なアルゴリズムが幅広く存在する。さらに、必ずしも *唯一の* 最良解を見つける必要はない。局所最適解や、その近似解であっても、依然として非常に有用である。 まとめ ------ - 訓練誤差を最小化しても、汎化誤差を最小化する最良のパラメータ集合が見つかるとは限りません。 - 最適化問題には多くの局所最小値が存在しうる。 - 一般に問題は凸ではないため、さらに多くの鞍点が存在しうる。 - 勾配消失は最適化を停滞させることがある。しばしば、問題の再パラメータ化が役立つ。パラメータの良い初期化も有益である。 演習 ---- 1. たとえば隠れ層の次元が :math:`d` で出力が1つの、単純な1隠れ層MLPを考えよ。任意の局所最小値に対して、同一に振る舞う等価な解が少なくとも :math:`d!` 個あることを示せ。 2. 対称なランダム行列 :math:`\mathbf{M}` があり、その要素 :math:`M_{ij} = M_{ji}` はそれぞれある確率分布 :math:`p_{ij}` から生成されると仮定する。さらに :math:`p_{ij}(x) = p_{ij}(-x)`\ 、すなわち分布が対称であると仮定する(詳細は例えば :cite:t:`Wigner.1958` を参照)。 1. 固有値の分布も対称であることを証明せよ。すなわち、任意の固有ベクトル :math:`\mathbf{v}` に対して、対応する固有値 :math:`\lambda` が :math:`P(\lambda > 0) = P(\lambda < 0)` を満たす確率を示せ。 2. なぜ上の結果から *P(:raw-latex:`\lambda `> 0) = 0.5* は導けないのか。 3. 深層学習の最適化に関わる他の課題として、どのようなものが考えられるか。 4. (実際の)ボールを(実際の)鞍の上で釣り合わせたいと仮定する。 1. なぜこれは難しいのか。 2. この効果を最適化アルゴリズムにも利用できるか。