.. _sec_rnn-concise:
リカレントニューラルネットワークの簡潔な実装
============================================
ほとんどのスクラッチ実装と同様に、 :numref:`sec_rnn-scratch` は
各コンポーネントがどのように動作するかを理解できるように設計されていました。
しかし、日常的にRNNを使うときや 本番コードを書くときには、 実装時間
(一般的なモデルや関数に対するライブラリコードを提供してくれるため)
と計算時間 (これらのライブラリ実装を徹底的に最適化してくれるため)
の両方を削減できるライブラリに、より頼りたくなるだろう。 この節では、
深層学習フレームワークが提供する高水準APIを用いて、
同じ言語モデルをより効率的に実装する方法を示す。
まずはこれまでと同様に、 *タイムマシン* データセットを読み込む。
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
%load_ext d2lbook.tab
tab.interact_select('mxnet', 'pytorch', 'tensorflow', 'jax')
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
from d2l import torch as d2l
import torch
from torch import nn
from torch.nn import functional as F
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
from d2l import mxnet as d2l
from mxnet import np, npx
from mxnet.gluon import nn, rnn
npx.set_np()
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
from d2l import jax as d2l
from flax import linen as nn
from jax import numpy as jnp
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
from d2l import tensorflow as d2l
import tensorflow as tf
.. raw:: html
.. raw:: html
モデルの定義
------------
以下のクラスを、 高水準APIで実装されたRNNを用いて定義する。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class RNN(d2l.Module): #@save
"""高水準APIで実装されたRNNモデル。"""
def __init__(self, num_inputs, num_hiddens):
super().__init__()
self.save_hyperparameters()
self.rnn = nn.RNN(num_inputs, num_hiddens)
def forward(self, inputs, H=None):
return self.rnn(inputs, H)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class RNN(d2l.Module): #@save
"""高水準APIで実装されたRNNモデル。"""
def __init__(self, num_hiddens):
super().__init__()
self.save_hyperparameters()
self.rnn = rnn.RNN(num_hiddens)
def forward(self, inputs, H=None):
if H is None:
H, = self.rnn.begin_state(inputs.shape[1], ctx=inputs.ctx)
outputs, (H, ) = self.rnn(inputs, (H, ))
return outputs, H
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class RNN(nn.Module): #@save
"""高水準APIで実装されたRNNモデル。"""
num_hiddens: int
@nn.compact
def __call__(self, inputs, H=None):
raise NotImplementedError
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class RNN(d2l.Module): #@save
"""高水準APIで実装されたRNNモデル。"""
def __init__(self, num_hiddens):
super().__init__()
self.save_hyperparameters()
self.rnn = tf.keras.layers.SimpleRNN(
num_hiddens, return_sequences=True, return_state=True,
time_major=True)
def forward(self, inputs, H=None):
outputs, H = self.rnn(inputs, H)
return outputs, H
.. raw:: html
.. raw:: html
:numref:`sec_rnn-scratch` の ``RNNLMScratch`` クラスを継承して、 次の
``RNNLM`` クラスはRNNベースの完全な言語モデルを定義する。
別個の全結合出力層を作成する必要があることに注意せよ。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class RNNLM(d2l.RNNLMScratch): #@save
"""高水準APIで実装されたRNNベースの言語モデル。"""
def init_params(self):
self.linear = nn.LazyLinear(self.vocab_size)
def output_layer(self, hiddens):
return d2l.swapaxes(self.linear(hiddens), 0, 1)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class RNNLM(d2l.RNNLMScratch): #@save
"""高水準APIで実装されたRNNベースの言語モデル。"""
def init_params(self):
if tab.selected('mxnet'):
self.linear = nn.Dense(self.vocab_size, flatten=False)
self.initialize()
if tab.selected('tensorflow'):
self.linear = tf.keras.layers.Dense(self.vocab_size)
def output_layer(self, hiddens):
if tab.selected('mxnet'):
return d2l.swapaxes(self.linear(hiddens), 0, 1)
if tab.selected('tensorflow'):
return d2l.transpose(self.linear(hiddens), (1, 0, 2))
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class RNNLM(d2l.RNNLMScratch): #@save
"""高水準APIで実装されたRNNベースの言語モデル。"""
training: bool = True
def setup(self):
self.linear = nn.Dense(self.vocab_size)
def output_layer(self, hiddens):
return d2l.swapaxes(self.linear(hiddens), 0, 1)
def forward(self, X, state=None):
embs = self.one_hot(X)
rnn_outputs, _ = self.rnn(embs, state, self.training)
return self.output_layer(rnn_outputs)
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
class RNNLM(d2l.RNNLMScratch): #@save
"""高水準APIで実装されたRNNベースの言語モデル。"""
def init_params(self):
if tab.selected('mxnet'):
self.linear = nn.Dense(self.vocab_size, flatten=False)
self.initialize()
if tab.selected('tensorflow'):
self.linear = tf.keras.layers.Dense(self.vocab_size)
def output_layer(self, hiddens):
if tab.selected('mxnet'):
return d2l.swapaxes(self.linear(hiddens), 0, 1)
if tab.selected('tensorflow'):
return d2l.transpose(self.linear(hiddens), (1, 0, 2))
.. raw:: html
.. raw:: html
学習と予測
----------
モデルを学習する前に、ランダムな重みで初期化されたモデルを使って予測してみよう。
ネットワークはまだ学習されていないので、 意味をなさない予測を生成する。
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
data = d2l.TimeMachine(batch_size=1024, num_steps=32)
if tab.selected('mxnet', 'tensorflow'):
rnn = RNN(num_hiddens=32)
if tab.selected('pytorch'):
rnn = RNN(num_inputs=len(data.vocab), num_hiddens=32)
model = RNNLM(rnn, vocab_size=len(data.vocab), lr=1)
model.predict('it has', 20, data.vocab)
次に、高水準APIを活用して、モデルを学習する。
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
if tab.selected('mxnet', 'pytorch'):
trainer = d2l.Trainer(max_epochs=100, gradient_clip_val=1, num_gpus=1)
if tab.selected('tensorflow'):
with d2l.try_gpu():
trainer = d2l.Trainer(max_epochs=100, gradient_clip_val=1)
trainer.fit(model, data)
:numref:`sec_rnn-scratch` と比べると、
このモデルは同程度の困惑度を達成するが、
最適化された実装のおかげでより高速に動作する。
これまでと同様に、指定した接頭辞文字列に続く予測トークンを生成できる。
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
model.predict('it has', 20, data.vocab, d2l.try_gpu())
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
'it has it and the pean the'
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
model.predict('it has', 20, data.vocab, d2l.try_gpu())
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
'it has a dimensions the ti'
.. raw:: html
.. raw:: html
.. raw:: latex
\diilbookstyleinputcell
.. code:: python
model.predict('it has', 20, data.vocab)
.. raw:: latex
\diilbookstyleoutputcell
.. parsed-literal::
:class: output
'it has in the the the that'
.. raw:: html
.. raw:: html
まとめ
------
深層学習フレームワークの高水準APIは、標準的なRNNの実装を提供する。
これらのライブラリを使えば、標準モデルを再実装するために時間を浪費せずに済む。
さらに、 フレームワークの実装はしばしば高度に最適化されているため、
スクラッチ実装と比べて 大幅な(計算)性能向上が得られる。
演習
----
1. 高水準APIを使ってRNNモデルを過学習させることはできるだろうか。
2. :numref:`sec_sequence` の自己回帰モデルをRNNを用いて実装せよ。