パラメータの初期化
==================
パラメータへのアクセス方法がわかったので、
次はそれらを適切に初期化する方法を見ていこう。
適切な初期化の必要性については :numref:`sec_numerical_stability`
で議論した。
深層学習フレームワークは、各層に対してデフォルトのランダム初期化を提供している。
しかし、私たちはしばしば、さまざまな別の手順に従って重みを初期化したいことがある。フレームワークは、最も一般的に
使われる手順の多くを提供しており、さらにカスタム初期化子を作成することもできる。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
import torch
from torch import nn
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
from mxnet import init, np, npx
from mxnet.gluon import nn
npx.set_np()
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
from d2l import jax as d2l
from flax import linen as nn
import jax
from jax import numpy as jnp
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
import tensorflow as tf
.. raw:: html
.. raw:: html
デフォルトでは、PyTorch は重み行列とバイアス行列を
入力次元と出力次元に基づいて計算される範囲から一様にサンプリングして初期化する。
PyTorch の ``nn.init`` モジュールは、さまざまな
既定の初期化方法を提供している。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
net = nn.Sequential(nn.LazyLinear(8), nn.ReLU(), nn.LazyLinear(1))
X = torch.rand(size=(2, 4))
net(X).shape
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
torch.Size([2, 1])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
net = nn.Sequential()
net.add(nn.Dense(8, activation='relu'))
net.add(nn.Dense(1))
net.initialize() # Use the default initialization method
X = np.random.uniform(size=(2, 4))
net(X).shape
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
[07:09:26] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(2, 1)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
net = nn.Sequential([nn.Dense(8), nn.relu, nn.Dense(1)])
X = jax.random.uniform(d2l.get_key(), (2, 4))
params = net.init(d2l.get_key(), X)
net.apply(params, X).shape
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(2, 1)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
net = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(4, activation=tf.nn.relu),
tf.keras.layers.Dense(1),
])
X = tf.random.uniform((2, 4))
net(X).shape
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
TensorShape([2, 1])
.. raw:: html
.. raw:: html
組み込み初期化
--------------
まずは組み込みの初期化子を呼び出してみよう。
以下のコードでは、すべての重みパラメータを 標準偏差 0.01
のガウス乱数で初期化し、バイアスパラメータはゼロにする。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
def init_normal(module):
if type(module) == nn.Linear:
nn.init.normal_(module.weight, mean=0, std=0.01)
nn.init.zeros_(module.bias)
net.apply(init_normal)
net[0].weight.data[0], net[0].bias.data[0]
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(tensor([ 0.0082, -0.0003, 0.0037, -0.0037]), tensor(0.))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
# Here force_reinit ensures that parameters are freshly initialized even if
# they were already initialized previously
net.initialize(init=init.Normal(sigma=0.01), force_reinit=True)
net[0].weight.data()[0]
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
array([ 0.00354961, -0.00614133, 0.0107317 , 0.01830765])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
weight_init = nn.initializers.normal(0.01)
bias_init = nn.initializers.zeros
net = nn.Sequential([nn.Dense(8, kernel_init=weight_init, bias_init=bias_init),
nn.relu,
nn.Dense(1, kernel_init=weight_init, bias_init=bias_init)])
params = net.init(jax.random.PRNGKey(d2l.get_seed()), X)
layer_0 = params['params']['layers_0']
layer_0['kernel'][:, 0], layer_0['bias'][0]
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(Array([-0.00944084, 0.01526781, 0.01000232, 0.01020786], dtype=float32),
Array(0., dtype=float32))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
net = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(
4, activation=tf.nn.relu,
kernel_initializer=tf.random_normal_initializer(mean=0, stddev=0.01),
bias_initializer=tf.zeros_initializer()),
tf.keras.layers.Dense(1)])
net(X)
net.weights[0], net.weights[1]
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(,
)
.. raw:: html
.. raw:: html
すべてのパラメータを ある定数値(たとえば 1)に初期化することもできる。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
def init_constant(module):
if type(module) == nn.Linear:
nn.init.constant_(module.weight, 1)
nn.init.zeros_(module.bias)
net.apply(init_constant)
net[0].weight.data[0], net[0].bias.data[0]
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(tensor([1., 1., 1., 1.]), tensor(0.))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
net.initialize(init=init.Constant(1), force_reinit=True)
net[0].weight.data()[0]
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
array([1., 1., 1., 1.])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
weight_init = nn.initializers.constant(1)
net = nn.Sequential([nn.Dense(8, kernel_init=weight_init, bias_init=bias_init),
nn.relu,
nn.Dense(1, kernel_init=weight_init, bias_init=bias_init)])
params = net.init(jax.random.PRNGKey(d2l.get_seed()), X)
layer_0 = params['params']['layers_0']
layer_0['kernel'][:, 0], layer_0['bias'][0]
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(Array([1., 1., 1., 1.], dtype=float32), Array(0., dtype=float32))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
net = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(
4, activation=tf.nn.relu,
kernel_initializer=tf.keras.initializers.Constant(1),
bias_initializer=tf.zeros_initializer()),
tf.keras.layers.Dense(1),
])
net(X)
net.weights[0], net.weights[1]
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(,
)
.. raw:: html
.. raw:: html
特定のブロックに対して異なる初期化子を適用することもできる。
たとえば、以下では第1層を Xavier 初期化子で初期化し、第2層を 定数 42
で初期化する。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
def init_xavier(module):
if type(module) == nn.Linear:
nn.init.xavier_uniform_(module.weight)
def init_42(module):
if type(module) == nn.Linear:
nn.init.constant_(module.weight, 42)
net[0].apply(init_xavier)
net[2].apply(init_42)
print(net[0].weight.data[0])
print(net[2].weight.data)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
tensor([-0.4081, -0.0622, -0.6532, 0.2202])
tensor([[42., 42., 42., 42., 42., 42., 42., 42.]])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
net[0].weight.initialize(init=init.Xavier(), force_reinit=True)
net[1].initialize(init=init.Constant(42), force_reinit=True)
print(net[0].weight.data()[0])
print(net[1].weight.data())
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
[-0.26102373 0.15249556 -0.19274211 -0.24742058]
[[42. 42. 42. 42. 42. 42. 42. 42.]]
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
net = nn.Sequential([nn.Dense(8, kernel_init=nn.initializers.xavier_uniform(),
bias_init=bias_init),
nn.relu,
nn.Dense(1, kernel_init=nn.initializers.constant(42),
bias_init=bias_init)])
params = net.init(jax.random.PRNGKey(d2l.get_seed()), X)
params['params']['layers_0']['kernel'][:, 0], params['params']['layers_2']['kernel']
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(Array([-0.3472612 , -0.14502545, 0.5215495 , -0.6690141 ], dtype=float32),
Array([[42.],
[42.],
[42.],
[42.],
[42.],
[42.],
[42.],
[42.]], dtype=float32))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
net = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(
4,
activation=tf.nn.relu,
kernel_initializer=tf.keras.initializers.GlorotUniform()),
tf.keras.layers.Dense(
1, kernel_initializer=tf.keras.initializers.Constant(42)),
])
net(X)
print(net.layers[1].weights[0])
print(net.layers[2].weights[0])
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
.. raw:: html
.. raw:: html
カスタム初期化
~~~~~~~~~~~~~~
ときには、必要な初期化方法が
深層学習フレームワークに用意されていないことがある。
以下の例では、任意の重みパラメータ :math:`w`
に対して、次の奇妙な分布を用いる初期化子を定義する。
.. math::
\begin{aligned}
w \sim \begin{cases}
U(5, 10) & \textrm{ with probability } \frac{1}{4} \\
0 & \textrm{ with probability } \frac{1}{2} \\
U(-10, -5) & \textrm{ with probability } \frac{1}{4}
\end{cases}
\end{aligned}
ここでも、\ ``net`` に適用する ``my_init`` 関数を実装する。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
def my_init(module):
if type(module) == nn.Linear:
print("Init", *[(name, param.shape)
for name, param in module.named_parameters()][0])
nn.init.uniform_(module.weight, -10, 10)
module.weight.data *= module.weight.data.abs() >= 5
net.apply(my_init)
net[0].weight[:2]
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
Init weight torch.Size([8, 4])
Init weight torch.Size([1, 8])
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
tensor([[ 0.0000, 8.1632, 7.1743, -0.0000],
[-7.6266, 9.0984, -0.0000, 5.0482]], grad_fn=)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class MyInit(init.Initializer):
def _init_weight(self, name, data):
print('Init', name, data.shape)
data[:] = np.random.uniform(-10, 10, data.shape)
data *= np.abs(data) >= 5
net.initialize(MyInit(), force_reinit=True)
net[0].weight.data()[:2]
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
Init dense0_weight (8, 4)
Init dense1_weight (1, 8)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
array([[-6.0683527, 8.991421 , -0. , 0. ],
[ 6.4198647, -9.728567 , -8.057975 , 0. ]])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
def my_init(key, shape, dtype=jnp.float_):
data = jax.random.uniform(key, shape, minval=-10, maxval=10)
return data * (jnp.abs(data) >= 5)
net = nn.Sequential([nn.Dense(8, kernel_init=my_init), nn.relu, nn.Dense(1)])
params = net.init(d2l.get_key(), X)
print(params['params']['layers_0']['kernel'][:, :2])
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
[[ 0. -9.883148 ]
[ 9.662153 -8.952053 ]
[-8.446951 8.958147 ]
[ 0. 6.0564566]]
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class MyInit(tf.keras.initializers.Initializer):
def __call__(self, shape, dtype=None):
data=tf.random.uniform(shape, -10, 10, dtype=dtype)
factor=(tf.abs(data) >= 5)
factor=tf.cast(factor, tf.float32)
return data * factor
net = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(
4,
activation=tf.nn.relu,
kernel_initializer=MyInit()),
tf.keras.layers.Dense(1),
])
net(X)
print(net.layers[1].weights[0])
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
.. raw:: html
.. raw:: html
パラメータを直接設定するという選択肢も常にあることに注意してほしい。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
net[0].weight.data[:] += 1
net[0].weight.data[0, 0] = 42
net[0].weight.data[0]
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
tensor([42.0000, 9.1632, 8.1743, 1.0000])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
net[0].weight.data()[:] += 1
net[0].weight.data()[0, 0] = 42
net[0].weight.data()[0]
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
array([42. , 9.991421, 1. , 1. ])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
net.layers[1].weights[0][:].assign(net.layers[1].weights[0] + 1)
net.layers[1].weights[0][0, 0].assign(42)
net.layers[1].weights[0]
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
.. raw:: html
.. raw:: html
まとめ
------
組み込み初期化子とカスタム初期化子を使ってパラメータを初期化できる。
演習
----
さらに多くの組み込み初期化子についてオンラインドキュメントを調べてみよう。