6.3. パラメータの初期化

パラメータへのアクセス方法がわかったので、 次はそれらを適切に初期化する方法を見ていこう。 適切な初期化の必要性については 5.4 章 で議論した。 深層学習フレームワークは、各層に対してデフォルトのランダム初期化を提供している。 しかし、私たちはしばしば、さまざまな別の手順に従って重みを初期化したいことがある。フレームワークは、最も一般的に 使われる手順の多くを提供しており、さらにカスタム初期化子を作成することもできる。

import torch
from torch import nn
from mxnet import init, np, npx
from mxnet.gluon import nn
npx.set_np()
from d2l import jax as d2l
from flax import linen as nn
import jax
from jax import numpy as jnp
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
import tensorflow as tf

デフォルトでは、PyTorch は重み行列とバイアス行列を 入力次元と出力次元に基づいて計算される範囲から一様にサンプリングして初期化する。 PyTorch の nn.init モジュールは、さまざまな 既定の初期化方法を提供している。

net = nn.Sequential(nn.LazyLinear(8), nn.ReLU(), nn.LazyLinear(1))
X = torch.rand(size=(2, 4))
net(X).shape
torch.Size([2, 1])
net = nn.Sequential()
net.add(nn.Dense(8, activation='relu'))
net.add(nn.Dense(1))
net.initialize()  # Use the default initialization method

X = np.random.uniform(size=(2, 4))
net(X).shape
[07:09:26] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU
(2, 1)
net = nn.Sequential([nn.Dense(8), nn.relu, nn.Dense(1)])
X = jax.random.uniform(d2l.get_key(), (2, 4))
params = net.init(d2l.get_key(), X)
net.apply(params, X).shape
(2, 1)
net = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(4, activation=tf.nn.relu),
    tf.keras.layers.Dense(1),
])

X = tf.random.uniform((2, 4))
net(X).shape
TensorShape([2, 1])

6.3.1. 組み込み初期化

まずは組み込みの初期化子を呼び出してみよう。 以下のコードでは、すべての重みパラメータを 標準偏差 0.01 のガウス乱数で初期化し、バイアスパラメータはゼロにする。

def init_normal(module):
    if type(module) == nn.Linear:
        nn.init.normal_(module.weight, mean=0, std=0.01)
        nn.init.zeros_(module.bias)

net.apply(init_normal)
net[0].weight.data[0], net[0].bias.data[0]
(tensor([ 0.0082, -0.0003,  0.0037, -0.0037]), tensor(0.))
# Here force_reinit ensures that parameters are freshly initialized even if
# they were already initialized previously
net.initialize(init=init.Normal(sigma=0.01), force_reinit=True)
net[0].weight.data()[0]
array([ 0.00354961, -0.00614133,  0.0107317 ,  0.01830765])
weight_init = nn.initializers.normal(0.01)
bias_init = nn.initializers.zeros

net = nn.Sequential([nn.Dense(8, kernel_init=weight_init, bias_init=bias_init),
                     nn.relu,
                     nn.Dense(1, kernel_init=weight_init, bias_init=bias_init)])

params = net.init(jax.random.PRNGKey(d2l.get_seed()), X)
layer_0 = params['params']['layers_0']
layer_0['kernel'][:, 0], layer_0['bias'][0]
(Array([-0.00944084,  0.01526781,  0.01000232,  0.01020786], dtype=float32),
 Array(0., dtype=float32))
net = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(
        4, activation=tf.nn.relu,
        kernel_initializer=tf.random_normal_initializer(mean=0, stddev=0.01),
        bias_initializer=tf.zeros_initializer()),
    tf.keras.layers.Dense(1)])

net(X)
net.weights[0], net.weights[1]
(<tf.Variable 'dense_2/kernel:0' shape=(4, 4) dtype=float32, numpy=
 array([[-0.0095617 ,  0.00378142, -0.01553275, -0.00099767],
        [-0.00153092,  0.00172021,  0.01361554, -0.02167145],
        [ 0.00229247, -0.00045055,  0.01760593, -0.00316169],
        [-0.0145327 ,  0.01879665, -0.00351739, -0.003721  ]],
       dtype=float32)>,
 <tf.Variable 'dense_2/bias:0' shape=(4,) dtype=float32, numpy=array([0., 0., 0., 0.], dtype=float32)>)

すべてのパラメータを ある定数値(たとえば 1)に初期化することもできる。

def init_constant(module):
    if type(module) == nn.Linear:
        nn.init.constant_(module.weight, 1)
        nn.init.zeros_(module.bias)

net.apply(init_constant)
net[0].weight.data[0], net[0].bias.data[0]
(tensor([1., 1., 1., 1.]), tensor(0.))
net.initialize(init=init.Constant(1), force_reinit=True)
net[0].weight.data()[0]
array([1., 1., 1., 1.])
weight_init = nn.initializers.constant(1)

net = nn.Sequential([nn.Dense(8, kernel_init=weight_init, bias_init=bias_init),
                     nn.relu,
                     nn.Dense(1, kernel_init=weight_init, bias_init=bias_init)])

params = net.init(jax.random.PRNGKey(d2l.get_seed()), X)
layer_0 = params['params']['layers_0']
layer_0['kernel'][:, 0], layer_0['bias'][0]
(Array([1., 1., 1., 1.], dtype=float32), Array(0., dtype=float32))
net = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(
        4, activation=tf.nn.relu,
        kernel_initializer=tf.keras.initializers.Constant(1),
        bias_initializer=tf.zeros_initializer()),
    tf.keras.layers.Dense(1),
])

net(X)
net.weights[0], net.weights[1]
(<tf.Variable 'dense_4/kernel:0' shape=(4, 4) dtype=float32, numpy=
 array([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]], dtype=float32)>,
 <tf.Variable 'dense_4/bias:0' shape=(4,) dtype=float32, numpy=array([0., 0., 0., 0.], dtype=float32)>)

特定のブロックに対して異なる初期化子を適用することもできる。 たとえば、以下では第1層を Xavier 初期化子で初期化し、第2層を 定数 42 で初期化する。

def init_xavier(module):
    if type(module) == nn.Linear:
        nn.init.xavier_uniform_(module.weight)

def init_42(module):
    if type(module) == nn.Linear:
        nn.init.constant_(module.weight, 42)

net[0].apply(init_xavier)
net[2].apply(init_42)
print(net[0].weight.data[0])
print(net[2].weight.data)
tensor([-0.4081, -0.0622, -0.6532,  0.2202])
tensor([[42., 42., 42., 42., 42., 42., 42., 42.]])
net[0].weight.initialize(init=init.Xavier(), force_reinit=True)
net[1].initialize(init=init.Constant(42), force_reinit=True)
print(net[0].weight.data()[0])
print(net[1].weight.data())
[-0.26102373  0.15249556 -0.19274211 -0.24742058]
[[42. 42. 42. 42. 42. 42. 42. 42.]]
net = nn.Sequential([nn.Dense(8, kernel_init=nn.initializers.xavier_uniform(),
                              bias_init=bias_init),
                     nn.relu,
                     nn.Dense(1, kernel_init=nn.initializers.constant(42),
                              bias_init=bias_init)])

params = net.init(jax.random.PRNGKey(d2l.get_seed()), X)
params['params']['layers_0']['kernel'][:, 0], params['params']['layers_2']['kernel']
(Array([-0.3472612 , -0.14502545,  0.5215495 , -0.6690141 ], dtype=float32),
 Array([[42.],
        [42.],
        [42.],
        [42.],
        [42.],
        [42.],
        [42.],
        [42.]], dtype=float32))
net = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(
        4,
        activation=tf.nn.relu,
        kernel_initializer=tf.keras.initializers.GlorotUniform()),
    tf.keras.layers.Dense(
        1, kernel_initializer=tf.keras.initializers.Constant(42)),
])

net(X)
print(net.layers[1].weights[0])
print(net.layers[2].weights[0])
<tf.Variable 'dense_6/kernel:0' shape=(4, 4) dtype=float32, numpy=
array([[ 0.62192386,  0.567151  , -0.83411425,  0.07775843],
       [-0.02987593,  0.69998187, -0.6425164 ,  0.07918394],
       [-0.81396955,  0.48030943, -0.45407206,  0.56459075],
       [ 0.5427986 , -0.27270615, -0.62875164, -0.8139801 ]],
      dtype=float32)>
<tf.Variable 'dense_7/kernel:0' shape=(4, 1) dtype=float32, numpy=
array([[42.],
       [42.],
       [42.],
       [42.]], dtype=float32)>

6.3.1.1. カスタム初期化

ときには、必要な初期化方法が 深層学習フレームワークに用意されていないことがある。 以下の例では、任意の重みパラメータ \(w\) に対して、次の奇妙な分布を用いる初期化子を定義する。

(6.3.1)\[\begin{split}\begin{aligned} w \sim \begin{cases} U(5, 10) & \textrm{ with probability } \frac{1}{4} \\ 0 & \textrm{ with probability } \frac{1}{2} \\ U(-10, -5) & \textrm{ with probability } \frac{1}{4} \end{cases} \end{aligned}\end{split}\]

ここでも、net に適用する my_init 関数を実装する。

def my_init(module):
    if type(module) == nn.Linear:
        print("Init", *[(name, param.shape)
                        for name, param in module.named_parameters()][0])
        nn.init.uniform_(module.weight, -10, 10)
        module.weight.data *= module.weight.data.abs() >= 5

net.apply(my_init)
net[0].weight[:2]
Init weight torch.Size([8, 4])
Init weight torch.Size([1, 8])
tensor([[ 0.0000,  8.1632,  7.1743, -0.0000],
        [-7.6266,  9.0984, -0.0000,  5.0482]], grad_fn=<SliceBackward0>)
class MyInit(init.Initializer):
    def _init_weight(self, name, data):
        print('Init', name, data.shape)
        data[:] = np.random.uniform(-10, 10, data.shape)
        data *= np.abs(data) >= 5

net.initialize(MyInit(), force_reinit=True)
net[0].weight.data()[:2]
Init dense0_weight (8, 4)
Init dense1_weight (1, 8)
array([[-6.0683527,  8.991421 , -0.       ,  0.       ],
       [ 6.4198647, -9.728567 , -8.057975 ,  0.       ]])
def my_init(key, shape, dtype=jnp.float_):
    data = jax.random.uniform(key, shape, minval=-10, maxval=10)
    return data * (jnp.abs(data) >= 5)

net = nn.Sequential([nn.Dense(8, kernel_init=my_init), nn.relu, nn.Dense(1)])
params = net.init(d2l.get_key(), X)
print(params['params']['layers_0']['kernel'][:, :2])
[[ 0.        -9.883148 ]
 [ 9.662153  -8.952053 ]
 [-8.446951   8.958147 ]
 [ 0.         6.0564566]]
class MyInit(tf.keras.initializers.Initializer):
    def __call__(self, shape, dtype=None):
        data=tf.random.uniform(shape, -10, 10, dtype=dtype)
        factor=(tf.abs(data) >= 5)
        factor=tf.cast(factor, tf.float32)
        return data * factor

net = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(
        4,
        activation=tf.nn.relu,
        kernel_initializer=MyInit()),
    tf.keras.layers.Dense(1),
])

net(X)
print(net.layers[1].weights[0])
<tf.Variable 'dense_8/kernel:0' shape=(4, 4) dtype=float32, numpy=
array([[-9.507499 ,  0.       , -7.383604 ,  7.671566 ],
       [ 8.4090805, -7.128086 ,  0.       ,  0.       ],
       [ 0.       ,  7.5970764,  0.       ,  0.       ],
       [-5.3449297,  6.064953 , -6.477902 ,  0.       ]], dtype=float32)>

パラメータを直接設定するという選択肢も常にあることに注意してほしい。

net[0].weight.data[:] += 1
net[0].weight.data[0, 0] = 42
net[0].weight.data[0]
tensor([42.0000,  9.1632,  8.1743,  1.0000])
net[0].weight.data()[:] += 1
net[0].weight.data()[0, 0] = 42
net[0].weight.data()[0]
array([42.      ,  9.991421,  1.      ,  1.      ])
net.layers[1].weights[0][:].assign(net.layers[1].weights[0] + 1)
net.layers[1].weights[0][0, 0].assign(42)
net.layers[1].weights[0]
<tf.Variable 'dense_8/kernel:0' shape=(4, 4) dtype=float32, numpy=
array([[42.       ,  1.       , -6.383604 ,  8.671566 ],
       [ 9.4090805, -6.128086 ,  1.       ,  1.       ],
       [ 1.       ,  8.597076 ,  1.       ,  1.       ],
       [-4.3449297,  7.064953 , -5.477902 ,  1.       ]], dtype=float32)>

6.3.2. まとめ

組み込み初期化子とカスタム初期化子を使ってパラメータを初期化できる。

6.3.3. 演習

さらに多くの組み込み初期化子についてオンラインドキュメントを調べてみよう。