.. code:: python

    class Module(d2l.nn_Module, d2l.HyperParameters):  #@save
        """The base class of models."""
        def __init__(self, plot_train_per_epoch=2, plot_valid_per_epoch=1):
            super().__init__()
            self.save_hyperparameters()
            self.board = ProgressBoard()

        def loss(self, y_hat, y):
            raise NotImplementedError

        def forward(self, X):
            assert hasattr(self, 'net'), 'Neural network is not defined'
            return self.net(X)

        def plot(self, key, value, train):
            """Plot a point in animation."""
            assert hasattr(self, 'trainer'), 'Trainer is not initialized'
            self.board.xlabel = 'epoch'
            if train:
                x = self.trainer.train_batch_idx / \
                    self.trainer.num_train_batches
                n = self.trainer.num_train_batches / \
                    self.plot_train_per_epoch
            else:
                x = self.trainer.epoch + 1
                n = self.trainer.num_val_batches / \
                    self.plot_valid_per_epoch
            self.board.draw(x, d2l.numpy(d2l.to(value, d2l.cpu())),
                            ('train_' if train else 'val_') + key,
                            every_n=int(n))

        def training_step(self, batch):
            l = self.loss(self(*batch[:-1]), batch[-1])
            self.plot('loss', l, train=True)
            return l

        def validation_step(self, batch):
            l = self.loss(self(*batch[:-1]), batch[-1])
            self.plot('loss', l, train=False)

        def configure_optimizers(self):
            raise NotImplementedError
.. code:: python

    class Module(d2l.nn_Module, d2l.HyperParameters):  #@save
        """The base class of models."""
        if tab.selected('mxnet', 'tensorflow'):
            def __init__(self, plot_train_per_epoch=2, plot_valid_per_epoch=1):
                super().__init__()
                self.save_hyperparameters()
                self.board = ProgressBoard()
                if tab.selected('tensorflow'):
                    self.training = None
        if tab.selected('jax'):
            # No need for save_hyperparam when using Python dataclass
            plot_train_per_epoch: int = field(default=2, init=False)
            plot_valid_per_epoch: int = field(default=1, init=False)
            # Use default_factory to make sure new plots are generated on each run
            board: ProgressBoard = field(default_factory=lambda: ProgressBoard(),
                                         init=False)

        def loss(self, y_hat, y):
            raise NotImplementedError

        if tab.selected('mxnet', 'tensorflow'):
            def forward(self, X):
                assert hasattr(self, 'net'), 'Neural network is not defined'
                return self.net(X)

        if tab.selected('tensorflow'):
            def call(self, X, *args, **kwargs):
                if kwargs and "training" in kwargs:
                    self.training = kwargs['training']
                return self.forward(X, *args)

        if tab.selected('jax'):
            # JAX & Flax do not have a forward-method-like syntax. Flax uses setup
            # and built-in __call__ magic methods for forward pass. Adding here
            # for consistency
            def forward(self, X, *args, **kwargs):
                assert hasattr(self, 'net'), 'Neural network is not defined'
                return self.net(X, *args, **kwargs)

            def __call__(self, X, *args, **kwargs):
                return self.forward(X, *args, **kwargs)

        def plot(self, key, value, train):
            """Plot a point in animation."""
            assert hasattr(self, 'trainer'), 'Trainer is not initialized'
            self.board.xlabel = 'epoch'
            if train:
                x = self.trainer.train_batch_idx / \
                    self.trainer.num_train_batches
                n = self.trainer.num_train_batches / \
                    self.plot_train_per_epoch
            else:
                x = self.trainer.epoch + 1
                n = self.trainer.num_val_batches / \
                    self.plot_valid_per_epoch
            if tab.selected('mxnet', 'tensorflow'):
                self.board.draw(x, d2l.numpy(value), (
                    'train_' if train else 'val_') + key, every_n=int(n))
            if tab.selected('jax'):
                self.board.draw(x, d2l.to(value, d2l.cpu()),
                                ('train_' if train else 'val_') + key,
                                every_n=int(n))

        if tab.selected('mxnet', 'tensorflow'):
            def training_step(self, batch):
                l = self.loss(self(*batch[:-1]), batch[-1])
                self.plot('loss', l, train=True)
                return l

            def validation_step(self, batch):
                l = self.loss(self(*batch[:-1]), batch[-1])
                self.plot('loss', l, train=False)

        if tab.selected('jax'):
            # JAX computes loss and gradients functionally from explicit
            # params and state rather than from module attributes
            def training_step(self, params, batch, state):
                l, grads = jax.value_and_grad(self.loss)(params, batch[:-1],
                                                         batch[-1], state)
                self.plot("loss", l, train=True)
                return l, grads

            def validation_step(self, params, batch, state):
                l = self.loss(params, batch[:-1], batch[-1], state)
                self.plot('loss', l, train=False)

            def apply_init(self, dummy_input, key):
                """To be defined later in :numref:`sec_lazy_init`"""
                raise NotImplementedError

        def configure_optimizers(self):
            raise NotImplementedError
You may notice that ``Module`` is a subclass of ``nn.Module``, the base
class of neural networks in PyTorch. It provides convenient functionality
for handling neural networks. For example, once we define a ``forward``
method such as ``forward(self, X)``, then for an instance ``a`` we can
invoke this method simply by writing ``a(X)``. This works because the
built-in ``__call__`` method invokes ``forward``. You can find more
details and examples about ``nn.Module`` in
:numref:`sec_model_construction`.
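As a quick, self-contained sketch of this dispatch (the ``Identity``
module below is a hypothetical example, not part of the book's code):

.. code:: python

    import torch
    from torch import nn

    class Identity(nn.Module):  # hypothetical module, for illustration only
        def forward(self, X):
            return X  # pass the input through unchanged

    layer = Identity()
    X = torch.randn(2, 3)
    # layer(X) goes through nn.Module.__call__, which dispatches to forward(X)
    assert torch.equal(layer(X), layer.forward(X))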
.. _oo-design-data:

Data
----
The ``DataModule`` class is the base class for data. Quite frequently the
``__init__`` method is used to prepare the data. This includes downloading
and preprocessing if needed. The ``train_dataloader`` returns the data
loader for the training dataset. A data loader is a (Python) generator
that yields a data batch each time it is used. This batch is then fed
into the ``training_step`` method of ``Module`` to compute the loss.
There is an optional ``val_dataloader`` to return the validation dataset
loader. It behaves in the same manner, except that it yields data batches
for the ``validation_step`` method in ``Module``.
.. code:: python

    class DataModule(d2l.HyperParameters):  #@save
        """The base class of data."""
        if tab.selected('mxnet', 'pytorch'):
            def __init__(self, root='../data', num_workers=4):
                self.save_hyperparameters()

        if tab.selected('tensorflow', 'jax'):
            def __init__(self, root='../data'):
                self.save_hyperparameters()

        def get_dataloader(self, train):
            raise NotImplementedError

        def train_dataloader(self):
            return self.get_dataloader(train=True)

        def val_dataloader(self):
            return self.get_dataloader(train=False)
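For instance, a minimal sketch of a concrete subclass might look as
follows, assuming the PyTorch tab; the ``ToyData`` class, its random
tensors, and the batch size are all hypothetical choices for
illustration:

.. code:: python

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    class ToyData(DataModule):  # hypothetical subclass, for illustration only
        def __init__(self, batch_size=32):
            super().__init__()
            self.save_hyperparameters()
            self.X = torch.randn(100, 2)  # 100 random examples with 2 features
            self.y = torch.randn(100, 1)  # 100 random targets

        def get_dataloader(self, train):
            # Yield (features, targets) batches; shuffle only during training
            dataset = TensorDataset(self.X, self.y)
            return DataLoader(dataset, self.batch_size, shuffle=train)

    X, y = next(iter(ToyData().train_dataloader()))
    X.shape, y.shape  # (torch.Size([32, 2]), torch.Size([32, 1]))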
.. _oo-design-training:

Training
--------
The ``Trainer`` class trains the learnable parameters of the ``Module``
class with data specified in ``DataModule``. The key method is ``fit``,
which accepts two arguments: ``model``, an instance of ``Module``, and
``data``, an instance of ``DataModule``. It then iterates over the entire
dataset ``max_epochs`` times to train the model. As before, we will defer
the implementation of this method to later chapters.
.. code:: python

    class Trainer(d2l.HyperParameters):  #@save
        """The base class for training models with data."""
        def __init__(self, max_epochs, num_gpus=0, gradient_clip_val=0):
            self.save_hyperparameters()
            assert num_gpus == 0, 'No GPU support yet'

        def prepare_data(self, data):
            self.train_dataloader = data.train_dataloader()
            self.val_dataloader = data.val_dataloader()
            self.num_train_batches = len(self.train_dataloader)
            self.num_val_batches = (len(self.val_dataloader)
                                    if self.val_dataloader is not None else 0)

        def prepare_model(self, model):
            model.trainer = self
            model.board.xlim = [0, self.max_epochs]
            self.model = model

        if tab.selected('pytorch', 'mxnet', 'tensorflow'):
            def fit(self, model, data):
                self.prepare_data(data)
                self.prepare_model(model)
                self.optim = model.configure_optimizers()
                self.epoch = 0
                self.train_batch_idx = 0
                self.val_batch_idx = 0
                for self.epoch in range(self.max_epochs):
                    self.fit_epoch()

        if tab.selected('jax'):
            def fit(self, model, data, key=None):
                self.prepare_data(data)
                self.prepare_model(model)
                self.optim = model.configure_optimizers()

                if key is None:
                    root_key = d2l.get_key()
                else:
                    root_key = key
                params_key, dropout_key = jax.random.split(root_key)
                key = {'params': params_key, 'dropout': dropout_key}

                dummy_input = next(iter(self.train_dataloader))[:-1]
                variables = model.apply_init(dummy_input, key=key)
                params = variables['params']

                if 'batch_stats' in variables.keys():
                    # Here batch_stats will be used later (e.g., for batch norm)
                    batch_stats = variables['batch_stats']
                else:
                    batch_stats = {}

                # Flax uses optax under the hood for a single state obj TrainState.
                # More will be discussed later in the dropout and batch
                # normalization section
                class TrainState(train_state.TrainState):
                    batch_stats: Any
                    dropout_rng: jax.random.PRNGKeyArray

                self.state = TrainState.create(apply_fn=model.apply,
                                               params=params,
                                               batch_stats=batch_stats,
                                               dropout_rng=dropout_key,
                                               tx=model.configure_optimizers())
                self.epoch = 0
                self.train_batch_idx = 0
                self.val_batch_idx = 0
                for self.epoch in range(self.max_epochs):
                    self.fit_epoch()

        def fit_epoch(self):
            raise NotImplementedError
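To see how the three classes interlock, here is a minimal smoke test
using the ``ToyData`` sketch from above, again assuming the PyTorch tab.
``NoOpTrainer`` and ``LinearModel`` are hypothetical stand-ins, and
``fit_epoch`` is stubbed out with a no-op because its real implementation
only arrives in later sections:

.. code:: python

    import torch

    class NoOpTrainer(Trainer):  # hypothetical stub, for illustration only
        def fit_epoch(self):
            pass  # a real trainer would iterate over the dataloaders here

    class LinearModel(Module):  # hypothetical Module subclass
        def __init__(self):
            super().__init__()
            self.net = torch.nn.Linear(2, 1)

        def loss(self, y_hat, y):
            return ((y_hat - y) ** 2).mean()

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=0.01)

    trainer = NoOpTrainer(max_epochs=3)
    # Wires up the dataloaders, model, and optimizer, then runs three
    # (empty) epochs without error
    trainer.fit(LinearModel(), ToyData())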
Summary
-------
To highlight the object-oriented design for our future deep learning
implementation, the classes above simply show how their objects store
data and interact with each other. We will keep enriching implementations
of these classes, for example via ``@add_to_class``, in the rest of the
book. Moreover, these fully implemented classes are saved in the `D2L
library