.. _sec_lazy_init:
遅延初期化
==========
ここまでのところ、ネットワークの構築においてかなり大雑把にやってもうまくいっているように見えたかもしれない。
具体的には、次のような直感に反することを行ってきた。これらは本来うまく動くようには思えないかもしれない。
- 入力次元を指定せずにネットワークアーキテクチャを定義した。
- 直前の層の出力次元を指定せずに層を追加した。
- さらに、モデルが何個のパラメータを持つべきかを決めるのに十分な情報を与える前に、これらのパラメータを「初期化」した。
コードが実際に動いていることに驚くかもしれない。
そもそも、深層学習フレームワークがネットワークの入力次元を知る方法はない。
ここでの工夫は、フレームワークが\ *初期化を遅延*\ し、最初にデータをモデルに通すまで待って、その場で各層のサイズを推論することである。
後で畳み込みニューラルネットワークを扱うときには、この手法はさらに便利になる。
なぜなら、入力次元 (たとえば画像の解像度)
が、その後に続く各層の次元に影響するからである。
したがって、コードを書く時点では次元の値を知らなくてもパラメータを設定できる能力は、モデルの指定やその後の修正を大幅に簡単にしてくれる。
それでは、初期化の仕組みをさらに詳しく見ていこう。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
from d2l import torch as d2l
import torch
from torch import nn
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
from mxnet import np, npx
from mxnet.gluon import nn
npx.set_np()
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
from d2l import jax as d2l
from flax import linen as nn
import jax
from jax import numpy as jnp
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
import tensorflow as tf
.. raw:: html
.. raw:: html
まず、MLP をインスタンス化してみよう。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
net = nn.Sequential(nn.LazyLinear(256), nn.ReLU(), nn.LazyLinear(10))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
net = nn.Sequential()
net.add(nn.Dense(256, activation='relu'))
net.add(nn.Dense(10))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
net = nn.Sequential([nn.Dense(256), nn.relu, nn.Dense(10)])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
net = tf.keras.models.Sequential([
tf.keras.layers.Dense(256, activation=tf.nn.relu),
tf.keras.layers.Dense(10),
])
.. raw:: html
.. raw:: html
この時点では、入力次元がまだ不明なので、ネットワークは入力層の重みの次元を知ることができない。
したがって、フレームワークはまだどのパラメータも初期化していない。
以下でパラメータにアクセスしようとして確認してみよう。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
net[0].weight
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
print(net.collect_params)
print(net.collect_params())
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
256, Activation(relu))
(1): Dense(-1 -> 10, linear)
)>
sequential0_ (
Parameter dense0_weight (shape=(256, -1), dtype=float32)
Parameter dense0_bias (shape=(256,), dtype=float32)
Parameter dense1_weight (shape=(10, -1), dtype=float32)
Parameter dense1_bias (shape=(10,), dtype=float32)
)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
params = net.init(d2l.get_key(), jnp.zeros((2, 20)))
jax.tree_util.tree_map(lambda x: x.shape, params).tree_flatten_with_keys()
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
(((DictKey(key='params'),
{'layers_0': {'bias': (256,), 'kernel': (20, 256)},
'layers_2': {'bias': (10,), 'kernel': (256, 10)}}),),
('params',))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
[net.layers[i].get_weights() for i in range(len(net.layers))]
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
[[], []]
.. raw:: html
.. raw:: html
次に、ネットワークにデータを通して、
フレームワークにようやくパラメータを初期化させてみよう。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X = torch.rand(2, 20)
net(X)
net[0].weight.shape
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
torch.Size([256, 20])
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
net.initialize()
net.collect_params()
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
[07:07:50] ../src/storage/storage.cc:196: Using Pooled (Naive) StorageManager for CPU
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
sequential0_ (
Parameter dense0_weight (shape=(256, -1), dtype=float32)
Parameter dense0_bias (shape=(256,), dtype=float32)
Parameter dense1_weight (shape=(10, -1), dtype=float32)
Parameter dense1_bias (shape=(10,), dtype=float32)
)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
@d2l.add_to_class(d2l.Module) #@save
def apply_init(self, dummy_input, key):
params = self.init(key, *dummy_input) # dummy_input tuple unpacked
return params
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X = tf.random.uniform((2, 20))
net(X)
[w.shape for w in net.get_weights()]
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
[(20, 256), (256,), (256, 10), (10,)]
.. raw:: html
.. raw:: html
入力次元 20 が分かればすぐに、 フレームワークは 20
の値を代入することで最初の層の重み行列の形状を特定できる。
最初の層の形状が分かると、フレームワークは次の層へ進み、
計算グラフに沿って順に処理し、 すべての形状が分かるまで続ける。
この場合、遅延初期化が必要なのは最初の層だけだが、フレームワークは順次に初期化を行う。
すべてのパラメータ形状が分かると、フレームワークはようやくパラメータを初期化できる。
次のメソッドは、 ダミー入力をネットワークに通して 予備実行を行い、
すべてのパラメータ形状を推論したうえで パラメータを初期化する。
これは、デフォルトのランダム初期化を望まない場合に後で使われる。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
@d2l.add_to_class(d2l.Module) #@save
def apply_init(self, inputs, init=None):
self.forward(*inputs)
if init is not None:
self.net.apply(init)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
X = np.random.uniform(size=(2, 20))
net(X)
net.collect_params()
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
sequential0_ (
Parameter dense0_weight (shape=(256, 20), dtype=float32)
Parameter dense0_bias (shape=(256,), dtype=float32)
Parameter dense1_weight (shape=(10, 256), dtype=float32)
Parameter dense1_bias (shape=(10,), dtype=float32)
)
.. raw:: html
.. raw:: html
要約
----
遅延初期化は便利である。フレームワークがパラメータ形状を自動的に推論できるため、アーキテクチャの修正が容易になり、よくあるエラーの原因を一つ取り除ける。
モデルにデータを通すことで、フレームワークにようやくパラメータを初期化させることができる。
演習
----
1. 最初の層には入力次元を指定するが、その後の層には指定しない場合、どうなるか?
すぐに初期化されるか?
2. 次元が一致しないように指定した場合、どうなるか?
3. 入力の次元が変化する場合、何をする必要があるか? ヒント:
パラメータ共有を見てみよう。