.. _sec_weight_decay: 重み減衰 ======== 過学習の問題を特徴づけたので、最初の\ *正則化*\ 手法を導入できる。 過学習は、より多くの訓練データを集めることで常に緩和できることを思い出そう。 しかし、それにはコストがかかり、時間もかかり、 あるいは完全に私たちの制御外であることもあり、 短期的には不可能である。 ひとまず、私たちはすでに 資源が許す限り十分に高品質なデータを持っていると仮定し、 データセットが与えられたものとして扱うときに利用できる手段に集中しよう。 多項式回帰の例 (:numref:`subsec_polynomial-curve-fitting`) では、当てはめる多項式の次数を調整することで モデルの容量を制限できた。 実際、特徴量の数を制限することは 過学習を抑えるための一般的な手法である。 しかし、単に特徴量を切り捨てるだけでは あまりに大雑把すぎることがある。 多項式回帰の例に戻って、 高次元入力で何が起こりうるかを考えてみよう。 多変量データへの多項式の自然な拡張は *単項式*\ と呼ばれ、変数のべき乗の積にすぎない。 単項式の次数は、べきの和である。 たとえば、\ :math:`x_1^2 x_2` と :math:`x_3 x_5^2` はいずれも次数3の単項式である。 次数 :math:`d` の項の数は、\ :math:`d` が大きくなるにつれて 急速に爆発的に増えることに注意されたい。 :math:`k` 個の変数があるとき、次数 :math:`d` の単項式の数は :math:`{k - 1 + d} \choose {k - 1}` である。 次数を :math:`2` から :math:`3` に変えるといった小さな変化でも、 モデルの複雑さは劇的に増加する。 したがって、関数の複雑さを調整するには、 よりきめ細かな手段がしばしば必要になる。 .. raw:: html
pytorchmxnetjaxtensorflow
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python %matplotlib inline from d2l import torch as d2l import torch from torch import nn .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python %matplotlib inline from d2l import mxnet as d2l from mxnet import autograd, gluon, init, np, npx from mxnet.gluon import nn npx.set_np() .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python %matplotlib inline from d2l import jax as d2l import jax from jax import numpy as jnp import optax .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.) .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python %matplotlib inline from d2l import tensorflow as d2l import tensorflow as tf .. raw:: html
.. raw:: html
ノルムと重み減衰 ---------------- パラメータの数を直接操作する代わりに、 *重み減衰*\ は、パラメータが取りうる値を制限することで機能する。 深層学習以外の分野では、ミニバッチ確率的勾配降下法で最適化するとき、 より一般には :math:`\ell_2` 正則化と呼ばれる。 重み減衰は、パラメトリック機械学習モデルを正則化するための 最も広く使われている手法の一つかもしれない。 この手法は、すべての関数 :math:`f` の中で、 関数 :math:`f = 0` (すべての入力に値 :math:`0` を割り当てるもの)が ある意味で\ *最も単純*\ であり、 関数の複雑さはパラメータがゼロからどれだけ離れているかで測れる、 という基本的な直観に動機づけられている。 しかし、関数とゼロの間の距離を 正確にはどのように測ればよいのだろうか? 唯一の正解があるわけではない。 実際、関数解析の一部や バナッハ空間の理論を含む数学の大きな分野全体が、 このような問題に取り組むために存在している。 一つの単純な解釈としては、 線形関数 :math:`f(\mathbf{x}) = \mathbf{w}^\top \mathbf{x}` の複雑さを、その重みベクトルの何らかのノルム、たとえば :math:`\| \mathbf{w} \|^2` で測ることが考えられる。 : numref:``subsec_lin-algebra-norms`` で、 より一般的な :math:`\ell_p` ノルムの特殊な場合である :math:`\ell_2` ノルムと :math:`\ell_1` ノルムを導入したことを思い出そう。 小さな重みベクトルを確保する最も一般的な方法は、 そのノルムを損失最小化問題に罰則項として加えることである。 したがって、元の目的関数、 すなわち\ *訓練ラベルに対する予測損失を最小化すること*\ を、 新しい目的関数、 すなわち\ *予測損失と罰則項の和を最小化すること*\ に置き換える。 こうすると、重みベクトルが大きくなりすぎた場合、 学習アルゴリズムは訓練誤差の最小化よりも 重みノルム :math:`\| \mathbf{w} \|^2` の最小化に 注力するかもしれない。 それこそが私たちの望むことである。 コードで示すために、 : numref:``sec_linear_regression`` の線形回帰の例を再び取り上げる。 そこでは、損失は次のように与えられていた。 .. math:: L(\mathbf{w}, b) = \frac{1}{n}\sum_{i=1}^n \frac{1}{2}\left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right)^2. :math:`\mathbf{x}^{(i)}` は特徴量、 :math:`y^{(i)}` はデータ例 :math:`i` のラベルであり、\ :math:`(\mathbf{w}, b)` はそれぞれ重みパラメータとバイアスパラメータである。 重みベクトルの大きさに罰則を与えるには、 何らかの形で損失関数に :math:`\| \mathbf{w} \|^2` を加える必要があるが、 この新しい加法的な罰則に対して、モデルは標準的な損失を どのようにトレードオフすべきだろうか? 実際には、このトレードオフを *正則化定数* :math:`\lambda` によって特徴づける。 これは非負のハイパーパラメータであり、 検証データを用いて調整する。 .. math:: L(\mathbf{w}, b) + \frac{\lambda}{2} \|\mathbf{w}\|^2. :math:`\lambda = 0` なら、元の損失関数に戻りる。 :math:`\lambda > 0` なら、\ :math:`\| \mathbf{w} \|` の大きさを制限する。 :math:`2` で割るのは慣習である。 二次関数の微分を取るとき、 :math:`2` と :math:`1/2` が打ち消し合い、更新式が 見た目にもきれいで簡潔になる。 鋭い読者は、なぜ標準ノルム(すなわちユークリッド距離)ではなく、 二乗したノルムを使うのかと疑問に思うかもしれない。 これは計算上の都合である。 :math:`\ell_2` ノルムを二乗することで平方根が消え、 重みベクトルの各成分の二乗和だけが残る。 これにより、罰則項の微分を簡単に計算できる。 すなわち、和の微分は微分の和に等しいのである。 さらに、そもそもなぜ :math:`\ell_1` ノルムではなく :math:`\ell_2` ノルムを使うのか、と疑問に思うかもしれない。 実際、他の選択肢も有効であり、 統計学では広く使われている。 :math:`\ell_2` 正則化された線形モデルは古典的な *リッジ回帰*\ アルゴリズムを構成するが、 :math:`\ell_1` 正則化された線形回帰は 同様に基本的な統計手法であり、 一般に *ラッソ回帰* として知られている。 :math:`\ell_2` ノルムを使う一つの理由は、 重みベクトルの大きな成分に対して 特に大きな罰則を課すことである。 これにより学習アルゴリズムは、 より多くの特徴量に重みを均等に分配するモデルへと 偏りる。 実際には、これは単一変数の測定誤差に対して より頑健にするかもしれない。 対照的に、\ :math:`\ell_1` 罰則は、 他の重みをゼロにしてしまうことで、 少数の特徴量に重みを集中させるモデルを導きる。 これにより、\ *特徴選択*\ のための有効な手法が得られ、 別の理由から望ましいことがある。 たとえば、モデルが少数の特徴量にしか依存しないなら、 他の(捨てられた)特徴量についてデータを収集、保存、送信する必要が なくなるかもしれない。 :eq:`eq_linreg_batch_update` と同じ記法を用いると、 ミニバッチ確率的勾配降下法による :math:`\ell_2` 正則化回帰の更新は次のようになる。 .. math:: \begin{aligned} \mathbf{w} & \leftarrow \left(1- \eta\lambda \right) \mathbf{w} - \frac{\eta}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} \mathbf{x}^{(i)} \left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right). \end{aligned} これまでと同様に、推定値が観測値からどれだけずれているかに基づいて :math:`\mathbf{w}` を更新する。 しかし同時に、\ :math:`\mathbf{w}` の大きさをゼロへ向かって縮小する。 そのため、この手法はしばしば「重み減衰」と呼ばれる。 罰則項だけを考えると、 最適化アルゴリズムは学習の各ステップで重みを\ *減衰*\ させる。 特徴選択とは対照的に、重み減衰は関数の複雑さを 連続的に調整する仕組みを与えてくれる。 :math:`\lambda` が小さいほど :math:`\mathbf{w}` への制約は弱くなり、 一方で :math:`\lambda` が大きいほど :math:`\mathbf{w}` はより強く制約される。 対応するバイアス罰則 :math:`b^2` を含めるかどうかは 実装によって異なり、 ニューラルネットワークの層ごとに異なる場合もある。 多くの場合、バイアス項は正則化しない。 さらに、 :math:`\ell_2` 正則化が他の最適化アルゴリズムでは重み減衰と等価でない場合があるとしても、 重みの大きさを縮小することによる正則化という考え方自体は 依然として成り立ちる。 高次元線形回帰 -------------- 重み減衰の利点は、 単純な合成例を通して示すことができる。 まず、以前と同様に データを生成する: .. math:: y = 0.05 + \sum_{i = 1}^d 0.01 x_i + \epsilon \textrm{ where } \epsilon \sim \mathcal{N}(0, 0.01^2). この合成データセットでは、ラベルは入力の背後にある線形関数によって与えられ、 平均0、標準偏差0.01のガウスノイズによって 汚されている。 説明のために、 問題の次元を :math:`d = 200` に増やし、 20例しかない小さな訓練セットで学習することで、 過学習の影響を顕著にできる。 .. raw:: latex \diilbookstyleinputcell .. code:: python class Data(d2l.DataModule): def __init__(self, num_train, num_val, num_inputs, batch_size): self.save_hyperparameters() n = num_train + num_val if tab.selected('mxnet') or tab.selected('pytorch'): self.X = d2l.randn(n, num_inputs) noise = d2l.randn(n, 1) * 0.01 if tab.selected('tensorflow'): self.X = d2l.normal((n, num_inputs)) noise = d2l.normal((n, 1)) * 0.01 if tab.selected('jax'): self.X = jax.random.normal(jax.random.PRNGKey(0), (n, num_inputs)) noise = jax.random.normal(jax.random.PRNGKey(0), (n, 1)) * 0.01 w, b = d2l.ones((num_inputs, 1)) * 0.01, 0.05 self.y = d2l.matmul(self.X, w) + b + noise def get_dataloader(self, train): i = slice(0, self.num_train) if train else slice(self.num_train, None) return self.get_tensorloader([self.X, self.y], train, i) ゼロからの実装 -------------- では、重み減衰をゼロから実装してみよう。 ミニバッチ確率的勾配降下法が最適化手法なので、 元の損失関数に二乗した :math:`\ell_2` 罰則を加えるだけで十分である。 :math:`\ell_2` ノルム罰則の定義 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ この罰則を実装する最も便利な方法は、 すべての項をその場で二乗してから和を取ることかもしれない。 .. raw:: latex \diilbookstyleinputcell .. code:: python def l2_penalty(w): return d2l.reduce_sum(w**2) / 2 モデルの定義 ~~~~~~~~~~~~ 最終的なモデルでは、 線形回帰と二乗損失は :numref:`sec_linear_scratch` から変わっていないので、 ``d2l.LinearRegressionScratch`` のサブクラスを定義するだけでよいだろう。 ここでの唯一の変更点は、損失に罰則項が含まれることである。 .. raw:: html
pytorchmxnetjaxtensorflow
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python class WeightDecayScratch(d2l.LinearRegressionScratch): def __init__(self, num_inputs, lambd, lr, sigma=0.01): super().__init__(num_inputs, lr, sigma) self.save_hyperparameters() def loss(self, y_hat, y): return (super().loss(y_hat, y) + self.lambd * l2_penalty(self.w)) .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python class WeightDecayScratch(d2l.LinearRegressionScratch): def __init__(self, num_inputs, lambd, lr, sigma=0.01): super().__init__(num_inputs, lr, sigma) self.save_hyperparameters() def loss(self, y_hat, y): return (super().loss(y_hat, y) + self.lambd * l2_penalty(self.w)) .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python class WeightDecayScratch(d2l.LinearRegressionScratch): lambd: int = 0 def loss(self, params, X, y, state): return (super().loss(params, X, y, state) + self.lambd * l2_penalty(params['w'])) .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python class WeightDecayScratch(d2l.LinearRegressionScratch): def __init__(self, num_inputs, lambd, lr, sigma=0.01): super().__init__(num_inputs, lr, sigma) self.save_hyperparameters() def loss(self, y_hat, y): return (super().loss(y_hat, y) + self.lambd * l2_penalty(self.w)) .. raw:: html
.. raw:: html
次のコードは、20例の訓練セットでモデルを学習し、100例の検証セットで評価する。 .. raw:: latex \diilbookstyleinputcell .. code:: python data = Data(num_train=20, num_val=100, num_inputs=200, batch_size=5) trainer = d2l.Trainer(max_epochs=10) def train_scratch(lambd): model = WeightDecayScratch(num_inputs=200, lambd=lambd, lr=0.01) model.board.yscale='log' trainer.fit(model, data) if tab.selected('pytorch', 'mxnet', 'tensorflow'): print('L2 norm of w:', float(l2_penalty(model.w))) if tab.selected('jax'): print('L2 norm of w:', float(l2_penalty(trainer.state.params['w']))) .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output [07:08:03] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU 正則化なしでの学習 ~~~~~~~~~~~~~~~~~~ ここでは ``lambd = 0`` としてこのコードを実行し、 重み減衰を無効にする。 訓練誤差は下がる一方で検証誤差は下がらず、 ひどく過学習していることに注意されたい。 これは過学習の典型例である。 .. raw:: latex \diilbookstyleinputcell .. code:: python train_scratch(0) .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output L2 norm of w: 0.011301761493086815 .. figure:: output_weight-decay_679df7_37_1.svg 重み減衰の使用 ~~~~~~~~~~~~~~ 以下では、かなり強い重み減衰をかけて実行する。 訓練誤差は増加するが、 検証誤差は減少することに注意されたい。 これはまさに正則化から期待される効果である。 .. raw:: latex \diilbookstyleinputcell .. code:: python train_scratch(3) .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output L2 norm of w: 0.0016639942768961191 .. figure:: output_weight-decay_679df7_39_1.svg 簡潔な実装 ---------- 重み減衰はニューラルネットワーク最適化で 至るところに使われているため、深層学習フレームワークでは特に便利に扱える。 最適化アルゴリズム自体に重み減衰を統合し、 任意の損失関数と組み合わせて簡単に使えるようにしている。 さらに、この統合は計算上の利点もあり、 追加の計算オーバーヘッドなしに、 実装上の工夫でアルゴリズムに重み減衰を加えられる。 更新の重み減衰部分は 各パラメータの現在値のみに依存するため、 最適化器はどうせ各パラメータに一度は触れる必要がある。 以下では、最適化器をインスタンス化するときに ``weight_decay`` を通じて重み減衰ハイパーパラメータを直接指定する。 デフォルトでは、PyTorch は 重みとバイアスの両方を同時に減衰させるが、 最適化器を異なるパラメータに対して異なる方針で扱うように設定できる。 ここでは、重みに対してのみ(\ ``net.weight`` パラメータに対してのみ) ``weight_decay`` を設定しているため、 バイアス(\ ``net.bias`` パラメータ)は減衰しない。 .. raw:: html
pytorchmxnetjaxtensorflow
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python class WeightDecay(d2l.LinearRegression): def __init__(self, wd, lr): super().__init__(lr) self.save_hyperparameters() self.wd = wd def configure_optimizers(self): return torch.optim.SGD([ {'params': self.net.weight, 'weight_decay': self.wd}, {'params': self.net.bias}], lr=self.lr) .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python class WeightDecay(d2l.LinearRegression): def __init__(self, wd, lr): super().__init__(lr) self.save_hyperparameters() self.wd = wd def configure_optimizers(self): self.collect_params('.*bias').setattr('wd_mult', 0) return gluon.Trainer(self.collect_params(), 'sgd', {'learning_rate': self.lr, 'wd': self.wd}) .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python class WeightDecay(d2l.LinearRegression): wd: int = 0 def configure_optimizers(self): # Weight Decay is not available directly within optax.sgd, but # optax allows chaining several transformations together return optax.chain(optax.additive_weight_decay(self.wd), optax.sgd(self.lr)) .. raw:: html
.. raw:: html
.. raw:: latex \diilbookstyleinputcell .. code:: python class WeightDecay(d2l.LinearRegression): def __init__(self, wd, lr): super().__init__(lr) self.save_hyperparameters() self.net = tf.keras.layers.Dense( 1, kernel_regularizer=tf.keras.regularizers.l2(wd), kernel_initializer=tf.keras.initializers.RandomNormal(0, 0.01) ) def loss(self, y_hat, y): return super().loss(y_hat, y) + self.net.losses .. raw:: html
.. raw:: html
プロットは、ゼロから重み減衰を実装したときと似ている。 しかし、この版のほうが高速に動作し、 実装も容易である。 問題が大きくなり、作業がより日常的になるにつれて、 これらの利点はさらに顕著になる。 .. raw:: latex \diilbookstyleinputcell .. code:: python model = WeightDecay(wd=3, lr=0.01) model.board.yscale='log' trainer.fit(model, data) if tab.selected('jax'): print('L2 norm of w:', float(l2_penalty(model.get_w_b(trainer.state)[0]))) if tab.selected('pytorch', 'mxnet', 'tensorflow'): print('L2 norm of w:', float(l2_penalty(model.get_w_b()[0]))) .. raw:: latex \diilbookstyleoutputcell .. parsed-literal:: :class: output L2 norm of w: 0.014530565589666367 .. figure:: output_weight-decay_679df7_56_1.svg ここまでで、単純な線形関数を構成するものについて 一つの考え方に触れた。 しかし、単純な非線形関数であっても、状況ははるかに複雑になりえる。これを理解する上で有用なのが\ `再生核ヒルベルト空間(RKHS) `__\ の概念であり、これを使うことで、 線形関数のために導入された道具を 非線形の文脈に適用できるようになる。 残念ながら、RKHS ベースのアルゴリズムは 大規模で高次元のデータに対しては スケーリングがうまくいかない傾向がある。 この本では、しばしば 重み減衰を深層ネットワークのすべての層に適用する という一般的なヒューリスティックを採用する。 まとめ ------ 正則化は過学習に対処するための一般的な方法である。古典的な正則化手法では、学習時に損失関数へ罰則項を加えることで、学習されたモデルの複雑さを抑える。 モデルを単純に保つための特定の選択肢の一つが、\ :math:`\ell_2` 罰則を使うことである。これにより、ミニバッチ確率的勾配降下法の更新ステップに重み減衰が現れる。 実際には、重み減衰の機能は深層学習フレームワークの最適化器に備わっている。 同じ訓練ループの中でも、異なるパラメータ集合に対して異なる更新挙動を持たせることができる。 演習 ---- 1. この節の推定問題で :math:`\lambda` の値を変えて実験しなさい。訓練精度と検証精度を :math:`\lambda` の関数としてプロットしなさい。何が観察できるか? 2. 検証セットを用いて :math:`\lambda` の最適値を見つけなさい。本当に最適値だろうか? それは重要だろうか? 3. 罰則として :math:`\|\mathbf{w}\|^2` の代わりに :math:`\sum_i |w_i|` を用いた場合(\ :math:`\ell_1` 正則化)、更新方程式はどのようになるか? 4. :math:`\|\mathbf{w}\|^2 = \mathbf{w}^\top \mathbf{w}` であることは分かっている。行列に対しても同様の式を見つけられますか(:numref:`subsec_lin-algebra-norms` のフロベニウスノルムを参照)? 5. 訓練誤差と汎化誤差の関係を復習しなさい。重み減衰に加えて、訓練の増加や適切な複雑さを持つモデルの使用以外に、過学習に対処するのに役立つ方法は何だろうか? 6. ベイズ統計では、事前分布と尤度の積を用いて :math:`P(w \mid x) \propto P(x \mid w) P(w)` により事後分布を得る。\ :math:`P(w)` を正則化とどのように対応づけられますか?