3.2. 実装のためのオブジェクト指向設計

線形回帰の導入では、 データ、モデル、損失関数、 最適化アルゴリズムを含む さまざまな構成要素を順に見てきた。 実際、 線形回帰は 機械学習モデルの中でも最も単純なものの一つである。 しかし、その学習には、 この本の他のモデルでも必要となる多くの同じ構成要素が使われる。 したがって、 実装の詳細に入る前に、 ここで使ういくつかのAPIを 設計しておく価値がある。 深層学習の構成要素を オブジェクトとして扱えば、 それらのオブジェクトと相互作用を定義する クラスを作ることから始められる。 この実装のためのオブジェクト指向設計は、 説明を大幅に整理するだけでなく、実際のプロジェクトでも有用にお使いいただけるだろう。

PyTorch Lightning のようなオープンソースライブラリに着想を得て、 大まかには次の3つのクラスを用意したいと考える。 (i) Module はモデル、損失、最適化手法を含む。 (ii) DataModule は学習用と検証用のデータローダーを提供する。 (iii) これら2つのクラスは Trainer クラスによって統合され、さまざまなハードウェアプラットフォーム上でモデルを学習できるようにする。 この本のコードの大部分は ModuleDataModule を拡張したものである。Trainer クラスに触れるのは、GPU、CPU、並列学習、最適化アルゴリズムを扱うときだけである。

import time
import numpy as np
from d2l import torch as d2l
import torch
from torch import nn
import time
import numpy as np
from d2l import mxnet as d2l
from mxnet.gluon import nn
from dataclasses import field
from d2l import jax as d2l
from flax import linen as nn
from flax.training import train_state
from jax import numpy as jnp
import numpy as np
import jax
import time
from typing import Any
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
import time
import numpy as np
from d2l import tensorflow as d2l
import tensorflow as tf

3.2.1. ユーティリティ

Jupyterノートブックでオブジェクト指向プログラミングを簡単にするために、いくつかのユーティリティが必要である。課題の一つは、クラス定義がかなり長いコードブロックになりがちなことである。ノートブックの読みやすさを保つには、説明を挟みながら短いコード片を並べる必要があるが、これはPythonライブラリで一般的なプログラミングスタイルとは相性がよくない。最初の ユーティリティ関数は、クラスが作成されたに、その関数をクラスのメソッドとして登録できるようにする。実際、クラスのインスタンスを作成したであっても可能である。これにより、クラスの実装を複数のコードブロックに分割できる。

def add_to_class(Class):  #@save
    """Register functions as methods in created class."""
    def wrapper(obj):
        setattr(Class, obj.__name__, obj)
    return wrapper

使い方を簡単に見てみよう。do というメソッドを持つクラス A を実装したいとする。Ado のコードを同じコードブロックに書く代わりに、まずクラス A を宣言してインスタンス a を作成できる。

class A:
    def __init__(self):
        self.b = 1

a = A()

次に、通常どおり do メソッドを定義するが、A クラスのスコープ内では定義しない。その代わり、このメソッドを add_to_class でデコレートし、引数としてクラス A を渡す。こうすることで、このメソッドは、あたかも A の定義の一部として含まれていたかのように、A のメンバー変数へアクセスできる。インスタンス a に対して呼び出すとどうなるか見てみよう。

@add_to_class(A)
def do(self):
    print('Class attribute "b" is', self.b)

a.do()
Class attribute "b" is 1

2つ目のユーティリティは、クラスの __init__ メソッドのすべての引数をクラス属性として保存するクラスである。これにより、追加のコードを書かずにコンストラクタの呼び出しシグネチャを暗黙的に拡張できる。

class HyperParameters:  #@save
    """The base class of hyperparameters."""
    def save_hyperparameters(self, ignore=[]):
        raise NotImplemented

その実装は 23.7 章 に回す。使うには、HyperParameters を継承し、__init__ メソッド内で save_hyperparameters を呼び出すクラスを定義する。

# Call the fully implemented HyperParameters class saved in d2l
class B(d2l.HyperParameters):
    def __init__(self, a, b, c):
        self.save_hyperparameters(ignore=['c'])
        print('self.a =', self.a, 'self.b =', self.b)
        print('There is no self.c =', not hasattr(self, 'c'))

b = B(a=1, b=2, c=3)
self.a = 1 self.b = 2
There is no self.c = True

最後のユーティリティは、実験の進行状況をその場で対話的に描画できるようにするものである。はるかに強力で複雑な TensorBoard に敬意を表して、これを ProgressBoard と名付ける。実装は 23.7 章 に回す。ここでは、まず動作だけ見てみよう。

draw メソッドは、図中に点 (x, y) を描画し、label を凡例に指定する。オプションの every_n は、図に \(1/n\) 個の点だけを表示することで線を滑らかにする。それらの値は、元の図における \(n\) 個の近傍点から平均される。

class ProgressBoard(d2l.HyperParameters):  #@save
    """The board that plots data points in animation."""
    def __init__(self, xlabel=None, ylabel=None, xlim=None,
                 ylim=None, xscale='linear', yscale='linear',
                 ls=['-', '--', '-.', ':'], colors=['C0', 'C1', 'C2', 'C3'],
                 fig=None, axes=None, figsize=(3.5, 2.5), display=True):
        self.save_hyperparameters()

    def draw(self, x, y, label, every_n=1):
        raise NotImplemented

次の例では、異なる滑らかさで sincos を描画する。このコードブロックを実行すると、線がアニメーションのように伸びていくのが見えるはずである。

board = d2l.ProgressBoard('x')
for x in np.arange(0, 10, 0.1):
    board.draw(x, np.sin(x), 'sin', every_n=2)
    board.draw(x, np.cos(x), 'cos', every_n=10)
../_images/output_oo-design_0fbe35_28_0.svg

3.2.2. モデル

Module クラスは、これから実装するすべてのモデルの基底クラスである。少なくとも3つのメソッドが必要である。1つ目の __init__ は学習可能なパラメータを保存し、training_step メソッドはデータバッチを受け取って損失値を返す。最後に、configure_optimizers は学習可能なパラメータを更新するために使う最適化手法、またはそのリストを返す。必要に応じて、評価指標を報告するための validation_step を定義できる。 出力の計算コードを別の forward メソッドに分けておくと、再利用しやすくなることがある。

class Module(d2l.nn_Module, d2l.HyperParameters):  #@save
    """The base class of models."""
    def __init__(self, plot_train_per_epoch=2, plot_valid_per_epoch=1):
        super().__init__()
        self.save_hyperparameters()
        self.board = ProgressBoard()

    def loss(self, y_hat, y):
        raise NotImplementedError

    def forward(self, X):
        assert hasattr(self, 'net'), 'Neural network is defined'
        return self.net(X)

    def plot(self, key, value, train):
        """Plot a point in animation."""
        assert hasattr(self, 'trainer'), 'Trainer is not inited'
        self.board.xlabel = 'epoch'
        if train:
            x = self.trainer.train_batch_idx / \
                self.trainer.num_train_batches
            n = self.trainer.num_train_batches / \
                self.plot_train_per_epoch
        else:
            x = self.trainer.epoch + 1
            n = self.trainer.num_val_batches / \
                self.plot_valid_per_epoch
        self.board.draw(x, d2l.numpy(d2l.to(value, d2l.cpu())),
                        ('train_' if train else 'val_') + key,
                        every_n=int(n))

    def training_step(self, batch):
        l = self.loss(self(*batch[:-1]), batch[-1])
        self.plot('loss', l, train=True)
        return l

    def validation_step(self, batch):
        l = self.loss(self(*batch[:-1]), batch[-1])
        self.plot('loss', l, train=False)

    def configure_optimizers(self):
        raise NotImplementedError
class Module(d2l.nn_Module, d2l.HyperParameters):  #@save
    """The base class of models."""
    if tab.selected('mxnet', 'tensorflow'):
        def __init__(self, plot_train_per_epoch=2, plot_valid_per_epoch=1):
            super().__init__()
            self.save_hyperparameters()
            self.board = ProgressBoard()
        if tab.selected('tensorflow'):
            self.training = None

    if tab.selected('jax'):
        # No need for save_hyperparam when using Python dataclass
        plot_train_per_epoch: int = field(default=2, init=False)
        plot_valid_per_epoch: int = field(default=1, init=False)
        # Use default_factory to make sure new plots are generated on each run
        board: ProgressBoard = field(default_factory=lambda: ProgressBoard(),
                                     init=False)

    def loss(self, y_hat, y):
        raise NotImplementedError

    if tab.selected('mxnet', 'tensorflow'):
        def forward(self, X):
            assert hasattr(self, 'net'), 'Neural network is defined'
            return self.net(X)

    if tab.selected('tensorflow'):
        def call(self, X, *args, **kwargs):
            if kwargs and "training" in kwargs:
                self.training = kwargs['training']
            return self.forward(X, *args)

    if tab.selected('jax'):
        # JAX & Flax do not have a forward-method-like syntax. Flax uses setup
        # and built-in __call__ magic methods for forward pass. Adding here
        # for consistency
        def forward(self, X, *args, **kwargs):
            assert hasattr(self, 'net'), 'Neural network is defined'
            return self.net(X, *args, **kwargs)

        def __call__(self, X, *args, **kwargs):
            return self.forward(X, *args, **kwargs)

    def plot(self, key, value, train):
        """Plot a point in animation."""
        assert hasattr(self, 'trainer'), 'Trainer is not inited'
        self.board.xlabel = 'epoch'
        if train:
            x = self.trainer.train_batch_idx / \
                self.trainer.num_train_batches
            n = self.trainer.num_train_batches / \
                self.plot_train_per_epoch
        else:
            x = self.trainer.epoch + 1
            n = self.trainer.num_val_batches / \
                self.plot_valid_per_epoch
        if tab.selected('mxnet', 'tensorflow'):
            self.board.draw(x, d2l.numpy(value), (
                'train_' if train else 'val_') + key, every_n=int(n))
        if tab.selected('jax'):
            self.board.draw(x, d2l.to(value, d2l.cpu()),
                            ('train_' if train else 'val_') + key,
                            every_n=int(n))

    if tab.selected('mxnet', 'tensorflow'):
        def training_step(self, batch):
            l = self.loss(self(*batch[:-1]), batch[-1])
            self.plot('loss', l, train=True)
            return l

        def validation_step(self, batch):
            l = self.loss(self(*batch[:-1]), batch[-1])
            self.plot('loss', l, train=False)

    if tab.selected('jax'):
        def training_step(self, params, batch, state):
            l, grads = jax.value_and_grad(self.loss)(params, batch[:-1],
                                                     batch[-1], state)
            self.plot("loss", l, train=True)
            return l, grads

        def validation_step(self, params, batch, state):
            l = self.loss(params, batch[:-1], batch[-1], state)
            self.plot('loss', l, train=False)

        def apply_init(self, dummy_input, key):
            """To be defined later in :numref:`sec_lazy_init`"""
            raise NotImplementedError

    def configure_optimizers(self):
        raise NotImplementedError
class Module(d2l.nn_Module, d2l.HyperParameters):  #@save
    """The base class of models."""
    if tab.selected('mxnet', 'tensorflow'):
        def __init__(self, plot_train_per_epoch=2, plot_valid_per_epoch=1):
            super().__init__()
            self.save_hyperparameters()
            self.board = ProgressBoard()
        if tab.selected('tensorflow'):
            self.training = None

    if tab.selected('jax'):
        # No need for save_hyperparam when using Python dataclass
        plot_train_per_epoch: int = field(default=2, init=False)
        plot_valid_per_epoch: int = field(default=1, init=False)
        # Use default_factory to make sure new plots are generated on each run
        board: ProgressBoard = field(default_factory=lambda: ProgressBoard(),
                                     init=False)

    def loss(self, y_hat, y):
        raise NotImplementedError

    if tab.selected('mxnet', 'tensorflow'):
        def forward(self, X):
            assert hasattr(self, 'net'), 'Neural network is defined'
            return self.net(X)

    if tab.selected('tensorflow'):
        def call(self, X, *args, **kwargs):
            if kwargs and "training" in kwargs:
                self.training = kwargs['training']
            return self.forward(X, *args)

    if tab.selected('jax'):
        # JAX & Flax do not have a forward-method-like syntax. Flax uses setup
        # and built-in __call__ magic methods for forward pass. Adding here
        # for consistency
        def forward(self, X, *args, **kwargs):
            assert hasattr(self, 'net'), 'Neural network is defined'
            return self.net(X, *args, **kwargs)

        def __call__(self, X, *args, **kwargs):
            return self.forward(X, *args, **kwargs)

    def plot(self, key, value, train):
        """Plot a point in animation."""
        assert hasattr(self, 'trainer'), 'Trainer is not inited'
        self.board.xlabel = 'epoch'
        if train:
            x = self.trainer.train_batch_idx / \
                self.trainer.num_train_batches
            n = self.trainer.num_train_batches / \
                self.plot_train_per_epoch
        else:
            x = self.trainer.epoch + 1
            n = self.trainer.num_val_batches / \
                self.plot_valid_per_epoch
        if tab.selected('mxnet', 'tensorflow'):
            self.board.draw(x, d2l.numpy(value), (
                'train_' if train else 'val_') + key, every_n=int(n))
        if tab.selected('jax'):
            self.board.draw(x, d2l.to(value, d2l.cpu()),
                            ('train_' if train else 'val_') + key,
                            every_n=int(n))

    if tab.selected('mxnet', 'tensorflow'):
        def training_step(self, batch):
            l = self.loss(self(*batch[:-1]), batch[-1])
            self.plot('loss', l, train=True)
            return l

        def validation_step(self, batch):
            l = self.loss(self(*batch[:-1]), batch[-1])
            self.plot('loss', l, train=False)

    if tab.selected('jax'):
        def training_step(self, params, batch, state):
            l, grads = jax.value_and_grad(self.loss)(params, batch[:-1],
                                                     batch[-1], state)
            self.plot("loss", l, train=True)
            return l, grads

        def validation_step(self, params, batch, state):
            l = self.loss(params, batch[:-1], batch[-1], state)
            self.plot('loss', l, train=False)

        def apply_init(self, dummy_input, key):
            """To be defined later in :numref:`sec_lazy_init`"""
            raise NotImplementedError

    def configure_optimizers(self):
        raise NotImplementedError
class Module(d2l.nn_Module, d2l.HyperParameters):  #@save
    """The base class of models."""
    if tab.selected('mxnet', 'tensorflow'):
        def __init__(self, plot_train_per_epoch=2, plot_valid_per_epoch=1):
            super().__init__()
            self.save_hyperparameters()
            self.board = ProgressBoard()
        if tab.selected('tensorflow'):
            self.training = None

    if tab.selected('jax'):
        # No need for save_hyperparam when using Python dataclass
        plot_train_per_epoch: int = field(default=2, init=False)
        plot_valid_per_epoch: int = field(default=1, init=False)
        # Use default_factory to make sure new plots are generated on each run
        board: ProgressBoard = field(default_factory=lambda: ProgressBoard(),
                                     init=False)

    def loss(self, y_hat, y):
        raise NotImplementedError

    if tab.selected('mxnet', 'tensorflow'):
        def forward(self, X):
            assert hasattr(self, 'net'), 'Neural network is defined'
            return self.net(X)

    if tab.selected('tensorflow'):
        def call(self, X, *args, **kwargs):
            if kwargs and "training" in kwargs:
                self.training = kwargs['training']
            return self.forward(X, *args)

    if tab.selected('jax'):
        # JAX & Flax do not have a forward-method-like syntax. Flax uses setup
        # and built-in __call__ magic methods for forward pass. Adding here
        # for consistency
        def forward(self, X, *args, **kwargs):
            assert hasattr(self, 'net'), 'Neural network is defined'
            return self.net(X, *args, **kwargs)

        def __call__(self, X, *args, **kwargs):
            return self.forward(X, *args, **kwargs)

    def plot(self, key, value, train):
        """Plot a point in animation."""
        assert hasattr(self, 'trainer'), 'Trainer is not inited'
        self.board.xlabel = 'epoch'
        if train:
            x = self.trainer.train_batch_idx / \
                self.trainer.num_train_batches
            n = self.trainer.num_train_batches / \
                self.plot_train_per_epoch
        else:
            x = self.trainer.epoch + 1
            n = self.trainer.num_val_batches / \
                self.plot_valid_per_epoch
        if tab.selected('mxnet', 'tensorflow'):
            self.board.draw(x, d2l.numpy(value), (
                'train_' if train else 'val_') + key, every_n=int(n))
        if tab.selected('jax'):
            self.board.draw(x, d2l.to(value, d2l.cpu()),
                            ('train_' if train else 'val_') + key,
                            every_n=int(n))

    if tab.selected('mxnet', 'tensorflow'):
        def training_step(self, batch):
            l = self.loss(self(*batch[:-1]), batch[-1])
            self.plot('loss', l, train=True)
            return l

        def validation_step(self, batch):
            l = self.loss(self(*batch[:-1]), batch[-1])
            self.plot('loss', l, train=False)

    if tab.selected('jax'):
        def training_step(self, params, batch, state):
            l, grads = jax.value_and_grad(self.loss)(params, batch[:-1],
                                                     batch[-1], state)
            self.plot("loss", l, train=True)
            return l, grads

        def validation_step(self, params, batch, state):
            l = self.loss(params, batch[:-1], batch[-1], state)
            self.plot('loss', l, train=False)

        def apply_init(self, dummy_input, key):
            """To be defined later in :numref:`sec_lazy_init`"""
            raise NotImplementedError

    def configure_optimizers(self):
        raise NotImplementedError

Module が PyTorch のニューラルネットワークの基底クラスである nn.Module のサブクラスであることに気づくかもしれない。 これはニューラルネットワークを扱うための便利な機能を提供する。たとえば、forward(self, X) のような forward メソッドを定義すると、インスタンス a に対して a(X) と書くだけでこのメソッドを呼び出せる。これは、組み込みの __call__ メソッドが forward を呼び出すためである。nn.Module についての詳細や例は 6.1 章 を参照されたい。

3.2.3. データ

DataModule クラスはデータのための基底クラスである。かなり頻繁に、__init__ メソッドはデータの準備に使われる。必要ならダウンロードや前処理も含まれる。train_dataloader は学習データセット用のデータローダーを返す。データローダーは、使われるたびにデータバッチを1つ返す(Pythonの)ジェネレータである。このバッチは Moduletraining_step メソッドに渡され、損失を計算する。検証データセット用のローダーを返す val_dataloader も任意で用意できる。こちらも同様に動作するが、Modulevalidation_step メソッドに渡すデータバッチを返す。

class DataModule(d2l.HyperParameters):  #@save
    """The base class of data."""
    if tab.selected('mxnet', 'pytorch'):
        def __init__(self, root='../data', num_workers=4):
            self.save_hyperparameters()

    if tab.selected('tensorflow', 'jax'):
        def __init__(self, root='../data'):
            self.save_hyperparameters()

    def get_dataloader(self, train):
        raise NotImplementedError

    def train_dataloader(self):
        return self.get_dataloader(train=True)

    def val_dataloader(self):
        return self.get_dataloader(train=False)

3.2.4. 学習

Trainer クラスは、DataModule で指定されたデータを使って Module クラスの学習可能パラメータを学習する。中心となるメソッドは fit で、2つの引数を受け取る。modelModule のインスタンス、dataDataModule のインスタンスである。その後、データセット全体を max_epochs 回繰り返してモデルを学習する。これまでと同様、このメソッドの実装は後の章に回す。

class Trainer(d2l.HyperParameters):  #@save
    """The base class for training models with data."""
    def __init__(self, max_epochs, num_gpus=0, gradient_clip_val=0):
        self.save_hyperparameters()
        assert num_gpus == 0, 'No GPU support yet'

    def prepare_data(self, data):
        self.train_dataloader = data.train_dataloader()
        self.val_dataloader = data.val_dataloader()
        self.num_train_batches = len(self.train_dataloader)
        self.num_val_batches = (len(self.val_dataloader)
                                if self.val_dataloader is not None else 0)

    def prepare_model(self, model):
        model.trainer = self
        model.board.xlim = [0, self.max_epochs]
        self.model = model

    if tab.selected('pytorch', 'mxnet', 'tensorflow'):
        def fit(self, model, data):
            self.prepare_data(data)
            self.prepare_model(model)
            self.optim = model.configure_optimizers()
            self.epoch = 0
            self.train_batch_idx = 0
            self.val_batch_idx = 0
            for self.epoch in range(self.max_epochs):
                self.fit_epoch()

    if tab.selected('jax'):
        def fit(self, model, data, key=None):
            self.prepare_data(data)
            self.prepare_model(model)
            self.optim = model.configure_optimizers()

            if key is None:
                root_key = d2l.get_key()
            else:
                root_key = key
            params_key, dropout_key = jax.random.split(root_key)
            key = {'params': params_key, 'dropout': dropout_key}

            dummy_input = next(iter(self.train_dataloader))[:-1]
            variables = model.apply_init(dummy_input, key=key)
            params = variables['params']

            if 'batch_stats' in variables.keys():
                # Here batch_stats will be used later (e.g., for batch norm)
                batch_stats = variables['batch_stats']
            else:
                batch_stats = {}

            # Flax uses optax under the hood for a single state obj TrainState.
            # More will be discussed later in the dropout and batch
            # normalization section
            class TrainState(train_state.TrainState):
                batch_stats: Any
                dropout_rng: jax.random.PRNGKeyArray

            self.state = TrainState.create(apply_fn=model.apply,
                                           params=params,
                                           batch_stats=batch_stats,
                                           dropout_rng=dropout_key,
                                           tx=model.configure_optimizers())
            self.epoch = 0
            self.train_batch_idx = 0
            self.val_batch_idx = 0
            for self.epoch in range(self.max_epochs):
                self.fit_epoch()

    def fit_epoch(self):
        raise NotImplementedError

3.2.5. 要約

深層学習の今後の実装に向けた オブジェクト指向設計を強調するために、 上のクラス群は、オブジェクトが どのようにデータを保存し、互いにやり取りするかを示しているにすぎない。 この本の残りでは、 @add_to_class などを通じて、 これらのクラスの実装をさらに充実させていく。 さらに、 これらの完全実装済みクラスは D2Lライブラリ に保存されており、 深層学習の構造化モデリングを容易にする軽量ツールキットである。 特に、プロジェクト間で多くの構成要素をほとんど変更せずに再利用できるようにする。たとえば、最適化手法だけ、モデルだけ、データセットだけを置き換えることができる。 この程度のモジュール性は、本書全体を通して簡潔さと単純さの面で大きな効果をもたらす(そのためにこれを追加した)。そして、あなた自身のプロジェクトでも同じ効果を発揮するだろう。

3.2.6. 演習

  1. D2Lライブラリ に保存されている上記クラスの完全実装を見つけよ。深層学習モデリングにもう少し慣れたら、実装を詳しく読むことを強く勧める。

  2. B クラスの save_hyperparameters 文を削除せよ。それでも self.aself.b を表示できるか? 任意: HyperParameters クラスの完全実装まで読み進めたなら、なぜそうなるのか説明できるか?