セイレンチュウ..

chainerでLSTMを実装し正弦波を学習させる

概要

LSTMをchainerで実装します。

5/8追記:バッチをなぜかバラしてfor文で1つずつ学習させるという愚行を行なっていたので修正しました。

この記事の対象



本編

LSTMとは

Long Short-Time Memory(長期短期記憶)の略です。 通常のニューラルネットワーク再帰的な処理に変更し時系列データに対応したものがRNNですが、それの一種です。RNNで問題となる、長い時系列データを扱った際に起こる勾配消失を、ゲートをつけることによって解決した手法です。
詳しくは「ゼロから作るディープラーニング② -自然言語処理編-」などに載っています。


実装

chainerについて

深層学習用Pythonライブラリはtensorflowやpytorchなどいろいろありますが、今回はPFNのchainerを利用します。chainerは2019年になってチュートリアルが公開されたので、わからないという方はまずそれを読んでみてください。

tutorials.chainer.org

chainerのNNの学習の仕組みをざっくりいうと、「datasetsから作ったiteratorネットワークをセットしたoptimizerを引数にしてupdaterを作り、それをtrainerで管理する」です。(いきなり単語ばかり並べられても訳がわからないと思いますが、詳しくは上で紹介したチュートリアルに載っているので読んでください。)

通常のNNと異なりLSTMでは時系列データを扱うため、既存のStandardUpdaterやIteratorクラスを継承しオーバーライドしてあげる必要があります。 順に書いていきます。

学習まで

まずは必要なものをまとめてimportしておきます。

import numpy as np
import math
import matplotlib.pyplot as plt

import chainer
import chainer.functions as F
import chainer.links as L
from chainer import reporter, training, datasets, iterators, optimizers, serializers
from chainer.training import extensions
from chainer.datasets import TupleDataset

モデル定義をします。これは通常のNNと基本的には同じです。
今回はl2層にLSTMをいれたシンプルなものです。最後の全結合層でLSTMの出力を1つの値に集約し、学習させます。別のやり方として、LSTMの最後の出力を用いて損失を計算させる方法もあります。
__init__でレイヤーを定義し、__call__に順伝搬処理を書きます。
一つだけ異なるのはreset_state()でLSTM層のリセットを定義しておくことです。これはLSTM層の出力とメモリセルを消去する関数らしいです(重みの値以外を初期化するイメージであってるはず)。学習済みモデルを使って検証するときなどに使います。

class lstm(chainer.Chain):
    def __init__(self,n_mid=10,n_out=1):
        super(lstm,self).__init__()
        with self.init_scope():
            self.l1 = L.Linear(None,n_mid)
            self.l2 = L.LSTM(n_mid,n_mid)
            self.l3 = L.Linear(n_mid,n_out)

    def reset_state(self):
        self.l2.reset_state()

    def __call__(self,x):
        h = self.l1(x)
        h = self.l2(h)
        h = self.l3(h)
        return h

イテレータを作ります。

class LSTM_Iterator(chainer.dataset.Iterator):
    def __init__(self, dataset, batch_size=10, seq_len=5, repeat=True):
        self.seq_length = seq_len
        self.dataset = dataset
        self.nsamples = len(dataset)

        self.batch_size = batch_size
        self.repeat = repeat

        self.epoch = 0
        self.iteration = 0
        self.offsets = np.random.randint(0, self.nsamples,size=self.batch_size)

        self.is_new_epoch = False

    def __next__(self):
        if not self.repeat and self.iteration * self.batch_size >= self.nsamples:
            raise StopIteration

        x,t = self.get_data()

        self.iteration += 1

        epoch = self.iteration * self.batch_size // self.nsamples

        self.is_new_epoch = self.epoch < epoch

        if self.is_new_epoch:
            self.epoch = epoch
            self.offsets = np.random.randint(0, self.nsamples,size=self.batch_size)
        return list(zip(x, t))

    @property
    def epoch_detail(self):
        return self.iteration * self.batch_size / len(self.dataset)

    def get_data(self):
        x = []
        for offset in self.offsets:
            tmp = []
            for i in range(self.seq_length):

                tmp.append(self.dataset[(offset + self.iteration - self.seq_length + (i+1)) % len(self.dataset)])
            x.append(tmp)
        t = [self.dataset[(offset + self.iteration + 1) % len(self.dataset)]
                for offset in self.offsets]
        return x,t

    def serialize(self, serializer):
        self.iteration = serializer('iteration', self.iteration)
        self.epoch     = serializer('epoch', self.epoch)

self.offsetsはバッチ学習の際に使うマスク、batch_sizeはバッチサイズ、seq_lengthは学習データ一つの長さです。get_data(self)は与えられたdatasetsに対して時系列を崩さず、seq_lengthの大きさのデータをxとして切り取り、その一つ後に続くデータを正解データとしてtに格納していく関数です。

Updaterも作ります。

class LSTM_updater(training.StandardUpdater):
    def __init__(self, train_iter, optimizer, device):
        super(LSTM_updater, self).__init__(train_iter, optimizer, device=device)
        self.seq_length = train_iter.seq_length
        self.batch_size = train_iter.batch_size

    def update_core(self):
        loss = 0

        train_iter = self.get_iterator('main')
        optimizer = self.get_optimizer('main')


        batch = train_iter.__next__()
        x, t  = self.converter(batch, self.device)
        loss += optimizer.target(chainer.Variable(x), chainer.Variable(t))

        optimizer.target.zerograds()
        loss.backward()
        loss.unchain_backward()
        optimizer.update()

先ほど作ったIteratorクラスの__next__()関数でbatchを読み込み、batchごとに計算した損失を合算し、逆伝搬させます。 (chainer/train_ptb.py at master · chainer/chainer · GitHubupdate_core(self)のループの様子が違い少し自信がありませんが、おそらくあっていると思います)
さて、それでは必要なクラスが作り終わったのでモデルを組んでいきましょう。

net = L.Classifier(lstm(),lossfun=F.mean_squared_error)
net.compute_accuracy = False
optimizer = optimizers.Adam()
optimizer.setup(net)

L.Classifierを利用します。分類問題に使うラッパーですが、lossfunをMSEなどに変えてnet.compute_accuracy = Falseとしてやることで今回のような予測にも使えます。
optimizerはAdamにします。(別に他のものでもいいです。)

次に訓練に使うデータを作ります。今回は学習できて当たり前の正弦波を扱います。

# データ作成
n_data = 600
sin_data = []
for i in range(n_data+1):
    sin_data.append(math.sin(i/50*math.pi))

# データセット
n_train = 500
n_test  = n_data-n_train


sin_data = np.array(sin_data).astype(np.float32)

x_train, x_test = sin_data[:n_train],sin_data[n_train:]


train = TupleDataset(x_train)
test  = TupleDataset(x_test)

n_seq = 5
train_iter = LSTM_Iterator(train, batch_size = 5, seq_len = n_seq)
test_iter  = LSTM_Iterator(test,  batch_size = 5, seq_len = n_seq, repeat = False)

コードを見てもらえばわかりますが、π/50刻みで600目盛分、すなわち[0,12π]の範囲のy=sin xのグラフが学習データです。(これをさらにtrainデータとvalidation用のtestデータに分けています。)
chainerで行う通常の深層学習のように、TupleDatasetにして、その後作ったLSTM_Iteratorに入れています。 n_seqseq_lengthと同じです。

いよいよ学習させます。

updater = LSTM_updater(train_iter, optimizer, -1)
trainer = training.Trainer(updater, (30, 'epoch'), out='results/lstm_result')

eval_model = net.copy()
eval_rnn = eval_model.predictor
eval_rnn.train = False
eval_rnn.reset_state()
trainer.extend(extensions.Evaluator(test_iter, eval_model, device=-1), name='val')

trainer.extend(extensions.LogReport(trigger=(1,'epoch'),log_name='log'))
trainer.extend(extensions.PrintReport(['epoch', 'main/loss', 'val/main/loss']))
trainer.extend(training.extensions.PlotReport(['main/loss','val/main/loss'], x_key='epoch', file_name='loss.png'))
trainer.extend(extensions.dump_graph('main/loss'))
# trainer.extend(extensions.ProgressBar())

trainer.run()
serializers.save_npz('lstm.npz',net)

ここら辺は特に説明することがないので割愛します。eval_modelを作った際にreset_state()しているのに注意してください。

さて、ここまで書いてきたコードをまとめて書いて実行することで学習が行われます。次は学習結果を検証してみます。

検証

正しく学習できていたかチェックします。まず、出力された損失の推移グラフを確認します。

f:id:alumi-tan:20190508014838p:plain
lossの折れ線プロット
epochを重ねるごとに、train,test共に損失が減っていることがわかります。 次に学習済みモデルを用いて、正弦波データセットの任意の連続した5個(今回はn_seq=5のため)を与えると次の値が予測できるのかを検証します。

def val1():
    net.predictor.reset_state()
    y = np.array([])
    for i in range(n_seq):
        new = net.predictor(chainer.Variable(x_train[100-n_seq+i:100-n_seq+i+n_seq].reshape(-1,n_seq)))
        y = np.append(y,new.data[-1])
    for i in range(n_train-n_seq):
        new = net.predictor(chainer.Variable(x_train[i:i+n_seq].reshape(-1,n_seq)))
        y = np.append(y,new.data[-1])

    plt.plot(range(len(y)),y,label='lstm')
    plt.plot(range(n_train),x_train,label='train')
    plt.legend()
    plt.show()

val1()

初めの5個のデータも予想するために、1周期後(2π=100*π/50)の5個前のデータ、すなわち95個目のデータを使っているという小細工がありますが、基本的には「π/50刻みで連続する5個の正弦波データを与えると次の値が予測できるのか」を確かめています。
結果は以下のようになります。

f:id:alumi-tan:20190508014527p:plain
val1

かなりうまく学習できているようです。

ちなみに、同じようにもう一つ検証してみます。上で試したものは1つ先の未来を予想することでした。では、その予測した未来を使ってさらに1つ先の未来を予想する、ということを繰り返していくことは可能なのでしょうか? 以下のコードを使います。

def val2():
    net.predictor.reset_state()
    y = x_train[:n_seq]

    for i in range(n_train):
        new = net.predictor(chainer.Variable(y[-n_seq:].reshape(-1,n_seq)))
        y = np.append(y,new.data[0])

    plt.plot(range(len(y)),y,label='lstm')
    plt.plot(range(n_train),x_train,label='train')
    plt.legend()
    plt.show()

val2()

結果は、散々かと思いきや意外と良い感じです。

f:id:alumi-tan:20190508014618p:plain
val2

当たり前ですが、真値とのズレが積み重なっていくので、今回のように簡単なデータでも苦戦します。周波数も少し変わっているようです。ここら辺はまた研究してみます。

まとめ

今回はLSTMをchainerで実装してみました。LSTMは自然言語処理モデルにもよく使われているので次はそっちを試してみたいです。




雑談

今回のLSTMですが、GW中ずっと悩んで、やっとできたー!って感じです。chainer自体チュートリアルが公開されたのをきっかけに勉強を始め、苦労したのですが、LSTMはさらに難しかったです。ネットにもあまりサンプルが転がってなかったため、chainerのコード自体を読み込んだり公式のサンプルを参考にしたりして、なんとか書き上げました。間違っているところや改善できる箇所があれば是非教えていただきたいです。(このコードだと計算グラフがうまく出力できないのでヘンテコな実装をしている可能性は大いにあります)

最後に参考にさせていただいたサイトを貼っておきます。ありがとうございました。

chainer/train_ptb.py at master · chainer/chainer · GitHub

LSTMにsin波を覚えてもらう(chainer trainerの速習) - Qiita

ChainerでLSTMを学習する手順を整理してみた | 自調自考の旅