ReformerをTrax (tensor2tensorの後継)で使ってみた
本記事の目的
- Reformerの著者実装 (Trax) を使ってみる!!!
- Traxで学習・推論するための一連の流れをまとめる
Reformerとは
- Transformerの大幅な計算効率の向上・省メモリ化に成功
したモデルです。計算リソースは大きなボトルネックになるので価値が高いです。
また、Transformerは各種BERT系統のベースにもなっています。
BERT系統のモデルは、急速に巨大化しおり、もはや計算リソース的に一般には手がだせなくなってきているので押さえておいて損はないと思います。
既に素晴らしい記事があるので詳しくは割愛します。ぜひ、以下をご覧になって下さい。
参考:
Traxとは
- Google Brain teamがメンテナンスしているtensor2tensorの後継
- Reformerの著者実装が含まれている
- TensorflowとJAXがベースの独立したMLフレームワーク。tf.kerasっぽく書ける
- GPU, TPUの使用が容易
- Document皆無
- 一部のlayerだけTensorflowやpytorchにもってくることはできない(多分)
Traxを使ってみる
以下の様に順を追ってtraxでの開発の流れをまとめます。
- modelの定義
- loss, optimizer, learning rate schedulerの定義
- datasetの準備
- Trainerの定義
- 学習
- 推論パイプラインの作成
なお、本記事のコードはcolabで動かしました。
1. modelの定義
Transformerやresnetなど主流なモデルがいくつか実装されています。(コード)
特に、Reformer関係だと以下の2つが実装されています。
name | 定義 |
---|---|
ReformerLM | language model (decoderのみ) |
Reformer | encoder-decoder model |
Convolution, pooling, RNNなどの主要なlayerやnormalization, activationあたりも豊富に実装されているので、tf.kerasのように層を重ねていくだけでニューラルネットワークモデルを柔軟に実装できます。
ただし、ドキュメントがほとんど何もないですし、ググっても何も引っかからないのでカスタムlayerをゴリゴリに作って、込み入ったことをしようとすると苦労すると思います。
Traxのmodelの実装
traxは独自のkeras.layersの様なlayerが定義されており(コード)、それらを組み合わせてmodelを実装しています。簡単のために、多層パーセプトロンモデルの実装を紹介します。
tl.Serial
をkeras.models.Sequential
と読み替えて貰えばだいたいイメージはつかめると思います。 (kerasになれていれば)
from trax import layers as tl def MLP(d_hidden=512, n_hidden_layers=2, Activation=tl.Relu, n_output_classes=10,mode=''): """A multi-layer feedforward (perceptron) network.""" del mode def DensePlusActivation(): return [tl.Dense(d_hidden), Activation()] return tl.Serial( tl.Flatten(), [DensePlusActivation() for _ in range(n_hidden_layers)], tl.Dense(n_output_classes), tl.LogSoftmax(), )
2. loss, optimizer, learning rate schedulerの定義
基本的なものは関数があります。以下や実装を参考にして使用するものを選択します。
loss - code
name | 定義 |
---|---|
L2Loss | 自乗和誤差 |
CrossEntropyLoss | 交差エントロピー誤差 |
optimizer - code
learning rate scheduling - code
実装を参照して下さい。
3. datasetの準備
datasetはtrax.supervised.Inputs()
インスタンスを作る必要があります。
まず、以下のようなgeneratorを用意します。ただし、trax.supervised.Inputs()
インスタンスはインターフェースを固定するためだけの存在なので、augmentationなどを行いたい場合は、別途tensorflow.data
などのエコシステムを使っていく必要があります。
また、datasetのshapeを(batch_size, ...)となるようにあらかじめ変換してから渡す必要があります。
trax特有の処理は特にないので今回は、サンプルコードをそのまま使います。
import numpy as onp import jax.numpy as np def copy_task(batch_size, vocab_size, length): """This task is to copy a random string w, so the input is 0w0w.""" while True: assert length % 2 == 0 w_length = (length // 2) - 1 w = onp.random.randint(low=1, high=vocab_size-1, size=(batch_size, w_length)) zero = onp.zeros([batch_size, 1], onp.int32) loss_weights = onp.concatenate([onp.zeros((batch_size, w_length)), onp.ones((batch_size, w_length+2))], axis=1) x = onp.concatenate([zero, w, zero, w], axis=1) yield (x, x, loss_weights) # =(Inputs, Targets, Weights)
このgeneratorを渡して以下の様に、trax.supervised.Inputs()
インスタンスを生成します。
batch_size = 16 vocab_size = 32 sentence_length = 10 copy_inputs = trax.supervised.Inputs(lambda _: copy_task(batch_size, vocab_size, sentence_length))
1つのデータの中身は、以下の様になっています。 今回の例は、入力した文章のコピーを推論するタスクです。5つ目のtokenまでが原文であり、その後の5つのtokenがコピーした文章になります。weightsはinputの各tokenのlossへの重み付けを行います。つまり、ここでは4文字目までは推論結果がlossに影響を与えません。5文字目、つまり、原文の最後のtokenを入力し、コピー文章を全て生成する所までの誤差を最小にするような学習を行います。
data_stream = copy_inputs.train_stream(1) inputs, targets, weights = next(data_stream) # inputs[0] -> [ 0, 5, 3, 14, 20, 0, 5, 3, 14, 20] # targets[0] -> [ 0, 5, 3, 14, 20, 0, 5, 3, 14, 20] # weights[0] -> [0., 0., 0., 0., 1., 1., 1., 1., 1., 1.]
4. Trainerの定義
次に、tl.Serial
インスタンスやloss, optimizerの関数定義、datasetを1つのtrax.supervised.Trainer
インスタンスにまとめます。
output_dir = os.path.expanduser('~/train_dir/') # モデルの保存先を指定 !rm -f ~/train_dir/model.pkl # 過去のモデルを削除 trainer = trax.supervised.Trainer( model=tiny_reformer_lm, # モデルの指定 loss_fn=trax.layers.CrossEntropyLoss, # lossの指定 optimizer=trax.optimizers.Adam, # optimizerの指定 lr_schedule=trax.lr.MultifactorSchedule, # lr scheduleの指定 inputs=copy_inputs, # datasetの指定 output_dir=output_dir, # モデルの保存先の指定 has_weights=True)
ここで、モデルは以下の様にReformerLM (Reformerのdecoder部分のみ)を使用することにします。
def tiny_reformer_lm(mode): return trax.models.ReformerLM( d_model=32, d_ff=128, n_layers=2, vocab_size=32, mode=mode)
5. 学習
trainerを使って、以下の様に学習を行うことができます。
n_epochs = 3 train_steps = 500 eval_steps = 2 # 1 epoch内で500 batchを学習に使い、2 batchでvalidationを行う。 # それを、計3 epoch行う。 for _ in range(n_epochs): trainer.train_epoch(train_steps, eval_steps)
以下のようなレポートが表示されます。
... Step 1500: Ran 500 train steps in 1.67 secs Step 1500: Evaluation Step 1500: train accuracy | 0.83854169 Step 1500: train loss | 0.57018900 Step 1500: train neg_log_perplexity | 0.57018900 Step 1500: train weights_per_batch_per_core | 96.00000000 Step 1500: eval accuracy | 0.83333337 Step 1500: eval loss | 0.58189404 Step 1500: eval neg_log_perplexity | 0.58189404 Step 1500: eval weights_per_batch_per_core | 96.00000000 Step 1500: Finished evaluation
余談
TransformerLM
で全く同じ条件で学習させた所、1epochで約2.2 secs程度でした。
入力の文章が短いのであまり大きな差は確認できませんでしたが、確かにReformerの方が速度はでていそうです。
6. 推論パイプラインの作成
最後に、学習したモデルを使って推論パイプラインをつくります。
まず、predict modeでreformerのinstanceを生成し、学習した重みを.init_from_file()
で読み込みます。ただし、途中でstateの初期値を変数にとっておきます。これは、predictの度にstateをcacheしていくので、1文章の処理毎に都度stateを初期化してあげる必要があるためです。
# Initialize model for inference. model_infer = tiny_reformer_lm(mode='predict') # Set up the initial state for sampling. initial_state = model_infer.new_weights_and_state( trax.supervised.trainer_lib.ShapeDtype((1,1), dtype=np.int32))[1] # load pretrained weight model_infer.init_from_file(os.path.join(output_dir, "model.pkl"), weights_only=True)
次に、データを入力させて実際に推論するための関数を用意します。ここで注意点として、仕様上、推論時はtokenを1つずつ入力しなければいけません。つまり文章の長さだけ繰り返し推論させる必要があります。 今回、原文のコピーを推論するタスクなので、前半部分は原文を出力し、後半のみモデルの推論結果を出力することにします。
def prediction(prompt, length=2048): """Sample from the ReformerLM model""" # Token id 0 is the equivalent of a "start" token model_infer.state = initial_state # stateの初期化 cur_inputs = np.zeros((1, 1), dtype=np.int32) # 初期値=0の挿入 prompt = np.asarray(prompt) all_samples = [] for iteration in range(length): logits = model_infer(cur_inputs) if iteration < prompt.shape[0]: cur_samples = onp.array(prompt[iteration], dtype=int) else: logits = onp.array(logits)[0,0,:] probs = onp.exp(logits) cur_samples = onp.random.choice(probs.shape[-1], p=probs[:]) cur_samples = onp.array(cur_samples, dtype=int) all_samples.append(cur_samples) cur_inputs = np.array(cur_samples[None,None]) all_samples = onp.stack(all_samples, -1) return all_samples
出力例は次のようになります。prefixが入力なので、コピーできていることが分かります。
# Run inference prefix = [0, 1, 2, 3, 2, 0] prediction(prefix, 10) # -> array([0, 1, 2, 3, 2, 0, 1, 2, 3, 2])
以上がTrax使用の流れです。
まとめ
- Reformerの著者実装 (Trax) を使ってみました!
- Traxで学習・推論するための一連の流れをまとめました!
Reformerを試すというよりかは1つのMLフレームワークさわってみた、という内容になりました。Reformerの実装を見る限り、フレームワークそのものも使いやすそうでした。Githubに再現実装がいくつかありますが、基本的に著者実装以外はメンテナンスが頼りない印象があるので現状は、(reformer関連だけは) Traxをそのまま使っていこかなーと思っています。
別記事で、以前書いたtransformerを使った記事のReformer版を試したり、 Reformerではなくattention layerだけ使ったモデルをTraxで書いてみたりしよーかなと考えております。
如何せん情報が全然ないので人柱になるつもりで色々試していきたいと思います!!!