3.2. 実装のためのオブジェクト指向設計¶
線形回帰の導入では、 データ、モデル、損失関数、 最適化アルゴリズムを含む さまざまな構成要素を順に見てきた。 実際、 線形回帰は 機械学習モデルの中でも最も単純なものの一つである。 しかし、その学習には、 この本の他のモデルでも必要となる多くの同じ構成要素が使われる。 したがって、 実装の詳細に入る前に、 ここで使ういくつかのAPIを 設計しておく価値がある。 深層学習の構成要素を オブジェクトとして扱えば、 それらのオブジェクトと相互作用を定義する クラスを作ることから始められる。 この実装のためのオブジェクト指向設計は、 説明を大幅に整理するだけでなく、実際のプロジェクトでも有用にお使いいただけるだろう。
PyTorch Lightning
のようなオープンソースライブラリに着想を得て、
大まかには次の3つのクラスを用意したいと考える。 (i) Module
はモデル、損失、最適化手法を含む。 (ii) DataModule
は学習用と検証用のデータローダーを提供する。 (iii) これら2つのクラスは
Trainer
クラスによって統合され、さまざまなハードウェアプラットフォーム上でモデルを学習できるようにする。
この本のコードの大部分は Module と DataModule
を拡張したものである。Trainer
クラスに触れるのは、GPU、CPU、並列学習、最適化アルゴリズムを扱うときだけである。
import time
import numpy as np
from d2l import torch as d2l
import torch
from torch import nn
import time
import numpy as np
from d2l import mxnet as d2l
from mxnet.gluon import nn
from dataclasses import field
from d2l import jax as d2l
from flax import linen as nn
from flax.training import train_state
from jax import numpy as jnp
import numpy as np
import jax
import time
from typing import Any
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
import time
import numpy as np
from d2l import tensorflow as d2l
import tensorflow as tf
3.2.1. ユーティリティ¶
Jupyterノートブックでオブジェクト指向プログラミングを簡単にするために、いくつかのユーティリティが必要である。課題の一つは、クラス定義がかなり長いコードブロックになりがちなことである。ノートブックの読みやすさを保つには、説明を挟みながら短いコード片を並べる必要があるが、これはPythonライブラリで一般的なプログラミングスタイルとは相性がよくない。最初の ユーティリティ関数は、クラスが作成された後に、その関数をクラスのメソッドとして登録できるようにする。実際、クラスのインスタンスを作成した後であっても可能である。これにより、クラスの実装を複数のコードブロックに分割できる。
def add_to_class(Class): #@save
"""Register functions as methods in created class."""
def wrapper(obj):
setattr(Class, obj.__name__, obj)
return wrapper
使い方を簡単に見てみよう。do というメソッドを持つクラス A
を実装したいとする。A と do
のコードを同じコードブロックに書く代わりに、まずクラス A
を宣言してインスタンス a を作成できる。
class A:
def __init__(self):
self.b = 1
a = A()
次に、通常どおり do メソッドを定義するが、A
クラスのスコープ内では定義しない。その代わり、このメソッドを
add_to_class でデコレートし、引数としてクラス A
を渡す。こうすることで、このメソッドは、あたかも A
の定義の一部として含まれていたかのように、A
のメンバー変数へアクセスできる。インスタンス a
に対して呼び出すとどうなるか見てみよう。
@add_to_class(A)
def do(self):
print('Class attribute "b" is', self.b)
a.do()
Class attribute "b" is 1
2つ目のユーティリティは、クラスの __init__
メソッドのすべての引数をクラス属性として保存するクラスである。これにより、追加のコードを書かずにコンストラクタの呼び出しシグネチャを暗黙的に拡張できる。
class HyperParameters: #@save
"""The base class of hyperparameters."""
def save_hyperparameters(self, ignore=[]):
raise NotImplemented
その実装は 23.7 章 に回す。使うには、HyperParameters
を継承し、__init__ メソッド内で save_hyperparameters
を呼び出すクラスを定義する。
# Call the fully implemented HyperParameters class saved in d2l
class B(d2l.HyperParameters):
def __init__(self, a, b, c):
self.save_hyperparameters(ignore=['c'])
print('self.a =', self.a, 'self.b =', self.b)
print('There is no self.c =', not hasattr(self, 'c'))
b = B(a=1, b=2, c=3)
self.a = 1 self.b = 2
There is no self.c = True
最後のユーティリティは、実験の進行状況をその場で対話的に描画できるようにするものである。はるかに強力で複雑な
TensorBoard
に敬意を表して、これを ProgressBoard と名付ける。実装は
23.7 章 に回す。ここでは、まず動作だけ見てみよう。
draw メソッドは、図中に点 (x, y) を描画し、label
を凡例に指定する。オプションの every_n は、図に \(1/n\)
個の点だけを表示することで線を滑らかにする。それらの値は、元の図における
\(n\) 個の近傍点から平均される。
class ProgressBoard(d2l.HyperParameters): #@save
"""The board that plots data points in animation."""
def __init__(self, xlabel=None, ylabel=None, xlim=None,
ylim=None, xscale='linear', yscale='linear',
ls=['-', '--', '-.', ':'], colors=['C0', 'C1', 'C2', 'C3'],
fig=None, axes=None, figsize=(3.5, 2.5), display=True):
self.save_hyperparameters()
def draw(self, x, y, label, every_n=1):
raise NotImplemented
次の例では、異なる滑らかさで sin と cos
を描画する。このコードブロックを実行すると、線がアニメーションのように伸びていくのが見えるはずである。
board = d2l.ProgressBoard('x')
for x in np.arange(0, 10, 0.1):
board.draw(x, np.sin(x), 'sin', every_n=2)
board.draw(x, np.cos(x), 'cos', every_n=10)
3.2.2. モデル¶
Module
クラスは、これから実装するすべてのモデルの基底クラスである。少なくとも3つのメソッドが必要である。1つ目の
__init__ は学習可能なパラメータを保存し、training_step
メソッドはデータバッチを受け取って損失値を返す。最後に、configure_optimizers
は学習可能なパラメータを更新するために使う最適化手法、またはそのリストを返す。必要に応じて、評価指標を報告するための
validation_step を定義できる。 出力の計算コードを別の forward
メソッドに分けておくと、再利用しやすくなることがある。
class Module(d2l.nn_Module, d2l.HyperParameters): #@save
"""The base class of models."""
def __init__(self, plot_train_per_epoch=2, plot_valid_per_epoch=1):
super().__init__()
self.save_hyperparameters()
self.board = ProgressBoard()
def loss(self, y_hat, y):
raise NotImplementedError
def forward(self, X):
assert hasattr(self, 'net'), 'Neural network is defined'
return self.net(X)
def plot(self, key, value, train):
"""Plot a point in animation."""
assert hasattr(self, 'trainer'), 'Trainer is not inited'
self.board.xlabel = 'epoch'
if train:
x = self.trainer.train_batch_idx / \
self.trainer.num_train_batches
n = self.trainer.num_train_batches / \
self.plot_train_per_epoch
else:
x = self.trainer.epoch + 1
n = self.trainer.num_val_batches / \
self.plot_valid_per_epoch
self.board.draw(x, d2l.numpy(d2l.to(value, d2l.cpu())),
('train_' if train else 'val_') + key,
every_n=int(n))
def training_step(self, batch):
l = self.loss(self(*batch[:-1]), batch[-1])
self.plot('loss', l, train=True)
return l
def validation_step(self, batch):
l = self.loss(self(*batch[:-1]), batch[-1])
self.plot('loss', l, train=False)
def configure_optimizers(self):
raise NotImplementedError
class Module(d2l.nn_Module, d2l.HyperParameters): #@save
"""The base class of models."""
if tab.selected('mxnet', 'tensorflow'):
def __init__(self, plot_train_per_epoch=2, plot_valid_per_epoch=1):
super().__init__()
self.save_hyperparameters()
self.board = ProgressBoard()
if tab.selected('tensorflow'):
self.training = None
if tab.selected('jax'):
# No need for save_hyperparam when using Python dataclass
plot_train_per_epoch: int = field(default=2, init=False)
plot_valid_per_epoch: int = field(default=1, init=False)
# Use default_factory to make sure new plots are generated on each run
board: ProgressBoard = field(default_factory=lambda: ProgressBoard(),
init=False)
def loss(self, y_hat, y):
raise NotImplementedError
if tab.selected('mxnet', 'tensorflow'):
def forward(self, X):
assert hasattr(self, 'net'), 'Neural network is defined'
return self.net(X)
if tab.selected('tensorflow'):
def call(self, X, *args, **kwargs):
if kwargs and "training" in kwargs:
self.training = kwargs['training']
return self.forward(X, *args)
if tab.selected('jax'):
# JAX & Flax do not have a forward-method-like syntax. Flax uses setup
# and built-in __call__ magic methods for forward pass. Adding here
# for consistency
def forward(self, X, *args, **kwargs):
assert hasattr(self, 'net'), 'Neural network is defined'
return self.net(X, *args, **kwargs)
def __call__(self, X, *args, **kwargs):
return self.forward(X, *args, **kwargs)
def plot(self, key, value, train):
"""Plot a point in animation."""
assert hasattr(self, 'trainer'), 'Trainer is not inited'
self.board.xlabel = 'epoch'
if train:
x = self.trainer.train_batch_idx / \
self.trainer.num_train_batches
n = self.trainer.num_train_batches / \
self.plot_train_per_epoch
else:
x = self.trainer.epoch + 1
n = self.trainer.num_val_batches / \
self.plot_valid_per_epoch
if tab.selected('mxnet', 'tensorflow'):
self.board.draw(x, d2l.numpy(value), (
'train_' if train else 'val_') + key, every_n=int(n))
if tab.selected('jax'):
self.board.draw(x, d2l.to(value, d2l.cpu()),
('train_' if train else 'val_') + key,
every_n=int(n))
if tab.selected('mxnet', 'tensorflow'):
def training_step(self, batch):
l = self.loss(self(*batch[:-1]), batch[-1])
self.plot('loss', l, train=True)
return l
def validation_step(self, batch):
l = self.loss(self(*batch[:-1]), batch[-1])
self.plot('loss', l, train=False)
if tab.selected('jax'):
def training_step(self, params, batch, state):
l, grads = jax.value_and_grad(self.loss)(params, batch[:-1],
batch[-1], state)
self.plot("loss", l, train=True)
return l, grads
def validation_step(self, params, batch, state):
l = self.loss(params, batch[:-1], batch[-1], state)
self.plot('loss', l, train=False)
def apply_init(self, dummy_input, key):
"""To be defined later in :numref:`sec_lazy_init`"""
raise NotImplementedError
def configure_optimizers(self):
raise NotImplementedError
class Module(d2l.nn_Module, d2l.HyperParameters): #@save
"""The base class of models."""
if tab.selected('mxnet', 'tensorflow'):
def __init__(self, plot_train_per_epoch=2, plot_valid_per_epoch=1):
super().__init__()
self.save_hyperparameters()
self.board = ProgressBoard()
if tab.selected('tensorflow'):
self.training = None
if tab.selected('jax'):
# No need for save_hyperparam when using Python dataclass
plot_train_per_epoch: int = field(default=2, init=False)
plot_valid_per_epoch: int = field(default=1, init=False)
# Use default_factory to make sure new plots are generated on each run
board: ProgressBoard = field(default_factory=lambda: ProgressBoard(),
init=False)
def loss(self, y_hat, y):
raise NotImplementedError
if tab.selected('mxnet', 'tensorflow'):
def forward(self, X):
assert hasattr(self, 'net'), 'Neural network is defined'
return self.net(X)
if tab.selected('tensorflow'):
def call(self, X, *args, **kwargs):
if kwargs and "training" in kwargs:
self.training = kwargs['training']
return self.forward(X, *args)
if tab.selected('jax'):
# JAX & Flax do not have a forward-method-like syntax. Flax uses setup
# and built-in __call__ magic methods for forward pass. Adding here
# for consistency
def forward(self, X, *args, **kwargs):
assert hasattr(self, 'net'), 'Neural network is defined'
return self.net(X, *args, **kwargs)
def __call__(self, X, *args, **kwargs):
return self.forward(X, *args, **kwargs)
def plot(self, key, value, train):
"""Plot a point in animation."""
assert hasattr(self, 'trainer'), 'Trainer is not inited'
self.board.xlabel = 'epoch'
if train:
x = self.trainer.train_batch_idx / \
self.trainer.num_train_batches
n = self.trainer.num_train_batches / \
self.plot_train_per_epoch
else:
x = self.trainer.epoch + 1
n = self.trainer.num_val_batches / \
self.plot_valid_per_epoch
if tab.selected('mxnet', 'tensorflow'):
self.board.draw(x, d2l.numpy(value), (
'train_' if train else 'val_') + key, every_n=int(n))
if tab.selected('jax'):
self.board.draw(x, d2l.to(value, d2l.cpu()),
('train_' if train else 'val_') + key,
every_n=int(n))
if tab.selected('mxnet', 'tensorflow'):
def training_step(self, batch):
l = self.loss(self(*batch[:-1]), batch[-1])
self.plot('loss', l, train=True)
return l
def validation_step(self, batch):
l = self.loss(self(*batch[:-1]), batch[-1])
self.plot('loss', l, train=False)
if tab.selected('jax'):
def training_step(self, params, batch, state):
l, grads = jax.value_and_grad(self.loss)(params, batch[:-1],
batch[-1], state)
self.plot("loss", l, train=True)
return l, grads
def validation_step(self, params, batch, state):
l = self.loss(params, batch[:-1], batch[-1], state)
self.plot('loss', l, train=False)
def apply_init(self, dummy_input, key):
"""To be defined later in :numref:`sec_lazy_init`"""
raise NotImplementedError
def configure_optimizers(self):
raise NotImplementedError
class Module(d2l.nn_Module, d2l.HyperParameters): #@save
"""The base class of models."""
if tab.selected('mxnet', 'tensorflow'):
def __init__(self, plot_train_per_epoch=2, plot_valid_per_epoch=1):
super().__init__()
self.save_hyperparameters()
self.board = ProgressBoard()
if tab.selected('tensorflow'):
self.training = None
if tab.selected('jax'):
# No need for save_hyperparam when using Python dataclass
plot_train_per_epoch: int = field(default=2, init=False)
plot_valid_per_epoch: int = field(default=1, init=False)
# Use default_factory to make sure new plots are generated on each run
board: ProgressBoard = field(default_factory=lambda: ProgressBoard(),
init=False)
def loss(self, y_hat, y):
raise NotImplementedError
if tab.selected('mxnet', 'tensorflow'):
def forward(self, X):
assert hasattr(self, 'net'), 'Neural network is defined'
return self.net(X)
if tab.selected('tensorflow'):
def call(self, X, *args, **kwargs):
if kwargs and "training" in kwargs:
self.training = kwargs['training']
return self.forward(X, *args)
if tab.selected('jax'):
# JAX & Flax do not have a forward-method-like syntax. Flax uses setup
# and built-in __call__ magic methods for forward pass. Adding here
# for consistency
def forward(self, X, *args, **kwargs):
assert hasattr(self, 'net'), 'Neural network is defined'
return self.net(X, *args, **kwargs)
def __call__(self, X, *args, **kwargs):
return self.forward(X, *args, **kwargs)
def plot(self, key, value, train):
"""Plot a point in animation."""
assert hasattr(self, 'trainer'), 'Trainer is not inited'
self.board.xlabel = 'epoch'
if train:
x = self.trainer.train_batch_idx / \
self.trainer.num_train_batches
n = self.trainer.num_train_batches / \
self.plot_train_per_epoch
else:
x = self.trainer.epoch + 1
n = self.trainer.num_val_batches / \
self.plot_valid_per_epoch
if tab.selected('mxnet', 'tensorflow'):
self.board.draw(x, d2l.numpy(value), (
'train_' if train else 'val_') + key, every_n=int(n))
if tab.selected('jax'):
self.board.draw(x, d2l.to(value, d2l.cpu()),
('train_' if train else 'val_') + key,
every_n=int(n))
if tab.selected('mxnet', 'tensorflow'):
def training_step(self, batch):
l = self.loss(self(*batch[:-1]), batch[-1])
self.plot('loss', l, train=True)
return l
def validation_step(self, batch):
l = self.loss(self(*batch[:-1]), batch[-1])
self.plot('loss', l, train=False)
if tab.selected('jax'):
def training_step(self, params, batch, state):
l, grads = jax.value_and_grad(self.loss)(params, batch[:-1],
batch[-1], state)
self.plot("loss", l, train=True)
return l, grads
def validation_step(self, params, batch, state):
l = self.loss(params, batch[:-1], batch[-1], state)
self.plot('loss', l, train=False)
def apply_init(self, dummy_input, key):
"""To be defined later in :numref:`sec_lazy_init`"""
raise NotImplementedError
def configure_optimizers(self):
raise NotImplementedError
Module が PyTorch のニューラルネットワークの基底クラスである
nn.Module のサブクラスであることに気づくかもしれない。
これはニューラルネットワークを扱うための便利な機能を提供する。たとえば、forward(self, X)
のような forward メソッドを定義すると、インスタンス a に対して
a(X) と書くだけでこのメソッドを呼び出せる。これは、組み込みの
__call__ メソッドが forward
を呼び出すためである。nn.Module についての詳細や例は
6.1 章 を参照されたい。
3.2.3. データ¶
DataModule
クラスはデータのための基底クラスである。かなり頻繁に、__init__
メソッドはデータの準備に使われる。必要ならダウンロードや前処理も含まれる。train_dataloader
は学習データセット用のデータローダーを返す。データローダーは、使われるたびにデータバッチを1つ返す(Pythonの)ジェネレータである。このバッチは
Module の training_step
メソッドに渡され、損失を計算する。検証データセット用のローダーを返す
val_dataloader
も任意で用意できる。こちらも同様に動作するが、Module の
validation_step メソッドに渡すデータバッチを返す。
class DataModule(d2l.HyperParameters): #@save
"""The base class of data."""
if tab.selected('mxnet', 'pytorch'):
def __init__(self, root='../data', num_workers=4):
self.save_hyperparameters()
if tab.selected('tensorflow', 'jax'):
def __init__(self, root='../data'):
self.save_hyperparameters()
def get_dataloader(self, train):
raise NotImplementedError
def train_dataloader(self):
return self.get_dataloader(train=True)
def val_dataloader(self):
return self.get_dataloader(train=False)
3.2.4. 学習¶
Trainer クラスは、DataModule で指定されたデータを使って
Module クラスの学習可能パラメータを学習する。中心となるメソッドは
fit で、2つの引数を受け取る。model は Module
のインスタンス、data は DataModule
のインスタンスである。その後、データセット全体を max_epochs
回繰り返してモデルを学習する。これまでと同様、このメソッドの実装は後の章に回す。
class Trainer(d2l.HyperParameters): #@save
"""The base class for training models with data."""
def __init__(self, max_epochs, num_gpus=0, gradient_clip_val=0):
self.save_hyperparameters()
assert num_gpus == 0, 'No GPU support yet'
def prepare_data(self, data):
self.train_dataloader = data.train_dataloader()
self.val_dataloader = data.val_dataloader()
self.num_train_batches = len(self.train_dataloader)
self.num_val_batches = (len(self.val_dataloader)
if self.val_dataloader is not None else 0)
def prepare_model(self, model):
model.trainer = self
model.board.xlim = [0, self.max_epochs]
self.model = model
if tab.selected('pytorch', 'mxnet', 'tensorflow'):
def fit(self, model, data):
self.prepare_data(data)
self.prepare_model(model)
self.optim = model.configure_optimizers()
self.epoch = 0
self.train_batch_idx = 0
self.val_batch_idx = 0
for self.epoch in range(self.max_epochs):
self.fit_epoch()
if tab.selected('jax'):
def fit(self, model, data, key=None):
self.prepare_data(data)
self.prepare_model(model)
self.optim = model.configure_optimizers()
if key is None:
root_key = d2l.get_key()
else:
root_key = key
params_key, dropout_key = jax.random.split(root_key)
key = {'params': params_key, 'dropout': dropout_key}
dummy_input = next(iter(self.train_dataloader))[:-1]
variables = model.apply_init(dummy_input, key=key)
params = variables['params']
if 'batch_stats' in variables.keys():
# Here batch_stats will be used later (e.g., for batch norm)
batch_stats = variables['batch_stats']
else:
batch_stats = {}
# Flax uses optax under the hood for a single state obj TrainState.
# More will be discussed later in the dropout and batch
# normalization section
class TrainState(train_state.TrainState):
batch_stats: Any
dropout_rng: jax.random.PRNGKeyArray
self.state = TrainState.create(apply_fn=model.apply,
params=params,
batch_stats=batch_stats,
dropout_rng=dropout_key,
tx=model.configure_optimizers())
self.epoch = 0
self.train_batch_idx = 0
self.val_batch_idx = 0
for self.epoch in range(self.max_epochs):
self.fit_epoch()
def fit_epoch(self):
raise NotImplementedError
3.2.5. 要約¶
深層学習の今後の実装に向けた オブジェクト指向設計を強調するために、
上のクラス群は、オブジェクトが
どのようにデータを保存し、互いにやり取りするかを示しているにすぎない。
この本の残りでは、 @add_to_class などを通じて、
これらのクラスの実装をさらに充実させていく。 さらに、
これらの完全実装済みクラスは
D2Lライブラリ
に保存されており、
深層学習の構造化モデリングを容易にする軽量ツールキットである。
特に、プロジェクト間で多くの構成要素をほとんど変更せずに再利用できるようにする。たとえば、最適化手法だけ、モデルだけ、データセットだけを置き換えることができる。
この程度のモジュール性は、本書全体を通して簡潔さと単純さの面で大きな効果をもたらす(そのためにこれを追加した)。そして、あなた自身のプロジェクトでも同じ効果を発揮するだろう。
3.2.6. 演習¶
D2Lライブラリ に保存されている上記クラスの完全実装を見つけよ。深層学習モデリングにもう少し慣れたら、実装を詳しく読むことを強く勧める。
Bクラスのsave_hyperparameters文を削除せよ。それでもself.aとself.bを表示できるか? 任意:HyperParametersクラスの完全実装まで読み進めたなら、なぜそうなるのか説明できるか?