メモ帳

python, juliaで機械学習をやっていく

ReformerをTrax (tensor2tensorの後継)で使ってみた

本記事の目的

  • Reformerの著者実装 (Trax) を使ってみる!!!
  • Traxで学習・推論するための一連の流れをまとめる

Reformerとは

  • Transformerの大幅な計算効率の向上・省メモリ化に成功

したモデルです。計算リソースは大きなボトルネックになるので価値が高いです。
また、Transformerは各種BERT系統のベースにもなっています。 BERT系統のモデルは、急速に巨大化しおり、もはや計算リソース的に一般には手がだせなくなってきているので押さえておいて損はないと思います。

既に素晴らしい記事があるので詳しくは割愛します。ぜひ、以下をご覧になって下さい。

参考:

Traxとは

github.com

  • Google Brain teamがメンテナンスしているtensor2tensorの後継
  • Reformerの著者実装が含まれている
  • TensorflowとJAXがベースの独立したMLフレームワーク。tf.kerasっぽく書ける
  • GPU, TPUの使用が容易
  • Document皆無
  • 一部のlayerだけTensorflowやpytorchにもってくることはできない(多分)

Traxを使ってみる

以下の様に順を追ってtraxでの開発の流れをまとめます。

  1. modelの定義
  2. loss, optimizer, learning rate schedulerの定義
  3. datasetの準備
  4. Trainerの定義
  5. 学習
  6. 推論パイプラインの作成

なお、本記事のコードはcolabで動かしました。

1. modelの定義

Transformerやresnetなど主流なモデルがいくつか実装されています。(コード)
特に、Reformer関係だと以下の2つが実装されています。

name 定義
ReformerLM language model (decoderのみ)
Reformer encoder-decoder model

Convolution, pooling, RNNなどの主要なlayerやnormalization, activationあたりも豊富に実装されているので、tf.kerasのように層を重ねていくだけでニューラルネットワークモデルを柔軟に実装できます。

ただし、ドキュメントがほとんど何もないですし、ググっても何も引っかからないのでカスタムlayerをゴリゴリに作って、込み入ったことをしようとすると苦労すると思います。

Traxのmodelの実装

traxは独自のkeras.layersの様なlayerが定義されており(コード)、それらを組み合わせてmodelを実装しています。簡単のために、多層パーセプトロンモデルの実装を紹介します。
tl.Serialkeras.models.Sequentialと読み替えて貰えばだいたいイメージはつかめると思います。 (kerasになれていれば)

from trax import layers as tl

def MLP(d_hidden=512,
        n_hidden_layers=2,
        Activation=tl.Relu,
        n_output_classes=10,mode=''):
    """A multi-layer feedforward (perceptron) network."""
    del mode

    def DensePlusActivation():
        return [tl.Dense(d_hidden), Activation()]

    return tl.Serial(
        tl.Flatten(),
        [DensePlusActivation() for _ in range(n_hidden_layers)],
        tl.Dense(n_output_classes),
        tl.LogSoftmax(),
    )

2. loss, optimizer, learning rate schedulerの定義

基本的なものは関数があります。以下や実装を参考にして使用するものを選択します。

loss - code

name 定義
L2Loss 自乗和誤差
CrossEntropyLoss 交差エントロピー誤差

optimizer - code

  • adafactor
  • adam
  • momentum
  • rms_prop
  • sgd
  • sm3

learning rate scheduling - code

実装を参照して下さい。

3. datasetの準備

datasetはtrax.supervised.Inputs()インスタンスを作る必要があります。 まず、以下のようなgeneratorを用意します。ただし、trax.supervised.Inputs()インスタンスはインターフェースを固定するためだけの存在なので、augmentationなどを行いたい場合は、別途tensorflow.dataなどのエコシステムを使っていく必要があります。
また、datasetのshapeを(batch_size, ...)となるようにあらかじめ変換してから渡す必要があります。
trax特有の処理は特にないので今回は、サンプルコードをそのまま使います。

import numpy as onp
import jax.numpy as np

def copy_task(batch_size, vocab_size, length):
    """This task is to copy a random string w, so the input is 0w0w."""
    while True:
        assert length % 2 == 0
        w_length = (length // 2) - 1
        w = onp.random.randint(low=1, high=vocab_size-1, size=(batch_size, w_length))
        zero = onp.zeros([batch_size, 1], onp.int32)
        loss_weights = onp.concatenate([onp.zeros((batch_size, w_length)),
                                    onp.ones((batch_size, w_length+2))], axis=1)
        x = onp.concatenate([zero, w, zero, w], axis=1)
        yield (x, x, loss_weights)  # =(Inputs, Targets, Weights)

このgeneratorを渡して以下の様に、trax.supervised.Inputs()インスタンスを生成します。

batch_size = 16
vocab_size = 32
sentence_length = 10
copy_inputs = trax.supervised.Inputs(lambda _: copy_task(batch_size, vocab_size, sentence_length))

1つのデータの中身は、以下の様になっています。 今回の例は、入力した文章のコピーを推論するタスクです。5つ目のtokenまでが原文であり、その後の5つのtokenがコピーした文章になります。weightsはinputの各tokenのlossへの重み付けを行います。つまり、ここでは4文字目までは推論結果がlossに影響を与えません。5文字目、つまり、原文の最後のtokenを入力し、コピー文章を全て生成する所までの誤差を最小にするような学習を行います。

data_stream = copy_inputs.train_stream(1)
inputs, targets, weights = next(data_stream)

# inputs[0] ->  [ 0, 5, 3, 14, 20, 0, 5, 3, 14, 20]
# targets[0] -> [ 0, 5, 3, 14, 20, 0, 5, 3, 14, 20]
# weights[0] -> [0., 0., 0., 0., 1., 1., 1., 1., 1., 1.]

4. Trainerの定義

次に、tl.Serial インスタンスやloss, optimizerの関数定義、datasetを1つのtrax.supervised.Trainerインスタンスにまとめます。

output_dir = os.path.expanduser('~/train_dir/') # モデルの保存先を指定
!rm -f ~/train_dir/model.pkl                    # 過去のモデルを削除

trainer = trax.supervised.Trainer(
    model=tiny_reformer_lm,                   # モデルの指定
    loss_fn=trax.layers.CrossEntropyLoss,     # lossの指定
    optimizer=trax.optimizers.Adam,           # optimizerの指定
    lr_schedule=trax.lr.MultifactorSchedule,  # lr scheduleの指定
    inputs=copy_inputs,                       # datasetの指定
    output_dir=output_dir,                    # モデルの保存先の指定
    has_weights=True)

ここで、モデルは以下の様にReformerLM (Reformerのdecoder部分のみ)を使用することにします。

def tiny_reformer_lm(mode):
    return trax.models.ReformerLM(
      d_model=32, d_ff=128, n_layers=2, vocab_size=32, mode=mode)

5. 学習

trainerを使って、以下の様に学習を行うことができます。

n_epochs  = 3
train_steps = 500
eval_steps = 2
# 1 epoch内で500 batchを学習に使い、2 batchでvalidationを行う。
# それを、計3 epoch行う。
for _ in range(n_epochs):
    trainer.train_epoch(train_steps, eval_steps)

以下のようなレポートが表示されます。

...
Step   1500: Ran 500 train steps in 1.67 secs
Step   1500: Evaluation
Step   1500: train                   accuracy |  0.83854169
Step   1500: train                       loss |  0.57018900
Step   1500: train         neg_log_perplexity |  0.57018900
Step   1500: train weights_per_batch_per_core |  96.00000000
Step   1500: eval                    accuracy |  0.83333337
Step   1500: eval                        loss |  0.58189404
Step   1500: eval          neg_log_perplexity |  0.58189404
Step   1500: eval  weights_per_batch_per_core |  96.00000000
Step   1500: Finished evaluation

余談

TransformerLMで全く同じ条件で学習させた所、1epochで約2.2 secs程度でした。 入力の文章が短いのであまり大きな差は確認できませんでしたが、確かにReformerの方が速度はでていそうです。

6. 推論パイプラインの作成

最後に、学習したモデルを使って推論パイプラインをつくります。

まず、predict modeでreformerのinstanceを生成し、学習した重みを.init_from_file()で読み込みます。ただし、途中でstateの初期値を変数にとっておきます。これは、predictの度にstateをcacheしていくので、1文章の処理毎に都度stateを初期化してあげる必要があるためです。

# Initialize model for inference.
model_infer = tiny_reformer_lm(mode='predict')

# Set up the initial state for sampling.
initial_state = model_infer.new_weights_and_state(
    trax.supervised.trainer_lib.ShapeDtype((1,1), dtype=np.int32))[1]

# load pretrained weight
model_infer.init_from_file(os.path.join(output_dir, "model.pkl"), weights_only=True)

次に、データを入力させて実際に推論するための関数を用意します。ここで注意点として、仕様上、推論時はtokenを1つずつ入力しなければいけません。つまり文章の長さだけ繰り返し推論させる必要があります。 今回、原文のコピーを推論するタスクなので、前半部分は原文を出力し、後半のみモデルの推論結果を出力することにします。

def prediction(prompt, length=2048):
    """Sample from the ReformerLM model"""
    # Token id 0 is the equivalent of a "start" token
    model_infer.state = initial_state             # stateの初期化
    cur_inputs = np.zeros((1, 1), dtype=np.int32) # 初期値=0の挿入
    prompt = np.asarray(prompt)
    all_samples = []

    for iteration in range(length):
        logits = model_infer(cur_inputs)
    
        if iteration < prompt.shape[0]:
            cur_samples = onp.array(prompt[iteration], dtype=int)
        else:
            logits = onp.array(logits)[0,0,:]
            probs = onp.exp(logits)
            cur_samples = onp.random.choice(probs.shape[-1], p=probs[:])
            cur_samples = onp.array(cur_samples, dtype=int)
        all_samples.append(cur_samples)
        cur_inputs = np.array(cur_samples[None,None])
    all_samples = onp.stack(all_samples, -1)

    return all_samples

出力例は次のようになります。prefixが入力なので、コピーできていることが分かります。

# Run inference
prefix = [0, 1, 2, 3, 2, 0] 
prediction(prefix, 10)
# -> array([0, 1, 2, 3, 2, 0, 1, 2, 3, 2])

以上がTrax使用の流れです。

まとめ

  • Reformerの著者実装 (Trax) を使ってみました!
  • Traxで学習・推論するための一連の流れをまとめました!

Reformerを試すというよりかは1つのMLフレームワークさわってみた、という内容になりました。Reformerの実装を見る限り、フレームワークそのものも使いやすそうでした。Githubに再現実装がいくつかありますが、基本的に著者実装以外はメンテナンスが頼りない印象があるので現状は、(reformer関連だけは) Traxをそのまま使っていこかなーと思っています。

別記事で、以前書いたtransformerを使った記事のReformer版を試したり、 Reformerではなくattention layerだけ使ったモデルをTraxで書いてみたりしよーかなと考えております。

如何せん情報が全然ないので人柱になるつもりで色々試していきたいと思います!!!