ONNXファイルから不要な枝を削ってMNISTの推論を高速化してみる

この記事の中のソースコードは全てhttps://github.com/akawashiro/sonnxにあります。

概要

背景

機械学習の学習済みモデルを小さなデバイスで動かす、というのが最近流行っているそうです。機械学習では、学習には大きな計算コストがかかりますが、推論はそれほど大きな計算コストがかかりません。このため、学習だけを別のコンピュータで行っておいて、実際の推論は小さなデバイスで行うということが可能です。

ただし、推論だけでもそれなりに計算資源が必要です。そこで、学習済みのモデルの高速化が重要になります。Raspberry Piに搭載されているGPUを使うIdeinとか有名です。

僕も学習済みモデルの推論を高速化できそうな方法を思いついたので実験してみました。

イデア

今回はMNISTを分類する学習済みモデルを高速化します。今回使用するモデルは次の図のようなものです。画像は28*28(=784)pxなので入力は784個、出力は各数字の確率なので10個あり、中間層が2つ挟まっています。各層間は全結合しており、活性化関数としてReluを使います。

このモデルを教師データを使って学習すると、枝の重みが変わってこんな感じになります。

僕のアイデアは学習後のネットワークから重みの小さい枝を取り去ってもちゃんと動くんじゃないか、というものです。重みの小さい枝を取り去るとこんな感じになります。

枝の本数が少なくなれば、推論を高速化できそうな気がします。

イデアの裏付け

実際に学習後のモデルで赤丸で囲った部分の重みの分布を確認します。

分布はこのようになっています。

重みが0の部分が非常に多いです。まず、学習済みモデルから重みが0の枝を削除しても推論結果に影響しないはずです。また、グラフが左右対称になっているので、絶対値の小さい順に削除していけば、各パーセプトロンへの入力はそれほど変化しない気がします。

手法

やることは非常に単純です。ニューラルネットワーク中の各層間枝の重みの統計を取り、重み上位何%かを残して残りを削除します。つまり、何らかの方法で学習済みのモデルから枝の重みのデータを取り出し、枝をカットし、さらに加工後のモデルのデータを使って推論できるようにします。

幸い今はONNXという良いものがあります。ONNXとは学習済みモデルのデータを出力する形式の一つで、多くのフレームワークが対応しています。

今回はChainerで書いたモデルからONNXデータを出力し、そのデータを加工することにしました。また加工後のデータはC++で書いた俺俺ONNXランタイムに実行してもらうことにしました。

纏めると、
1. Chainerでニューラルネットワークを書いて学習する
2. 学習済みのニューラルネットワークからONNXデータを出力する
3. ONNXデータを俺俺ONNXランタイムに読み込んで加工、実行する

となります。一つづつ何をやるのかを説明します。

1. Chainerでニューラルネットワークを書いて学習する

2. 学習済みのニューラルネットワークからONNXデータを出力する

1と2は簡単です。onnx-chainerを使えばすぐにできます。

python3 learn_mnist.py

mnist.onnxというファイルができます。

3. ONNXデータを俺俺ONNXランタイムに読み込んで加工、実行する

ここが大変でした。ディープラーニングフレームワークからのONNXモデルの出力は多くの人が試しているのですが、出力したONNXモデルをチューニングしようとする人はほとんどいないようです。

3.1 ONNXデータを解析する

とりあえずnetronというONNXの可視化ツールでmnist.onnxを可視化してみました。

GemmはGeneral matrix multiplyの略です。各Gemmノードは行列BとベクトルCを持ち、ベクトルxを入力としてBx+Cを出力します。Relu活性化関数です。

Reluはmax(0,x)で定義されている関数のでONNXから抽出する必要は無く、各Gemmノードの行列BとベクトルCの情報だけをを抽出できれば良いです。

今回は各GemmノードのBCをテキストファイルとして抽出します。

> python3 analyze_mnist_onnx.py

は次のファイルを出力します。

  • *************_matrix.txt
    mnist.onnxの全てのGemmノードのB行列とC行列
  • *************_matrix.png
    各行列の中の重みの分布
  • mnist.onnx.json
    ONNXファイルをJSONに変換したもの
  • mnist_train.txt
  • mnist_test.txt
    MNISTをC++から読み込みやすくするためにテキストファイルに変換したもの

*************_matrix.txtがどのGemmノードに対応するのかはmnist.onnx.jsonを睨むとわかります(←ここ超不親切)。

3.2 抽出したデータを加工、実行する
g++ -O3 -mtune=native -march=native -mfpmath=both sonnx.cpp && ./a.out

で出力した重みのデータを読み込み、推論を実行します。sonnx.cppは簡単なONNXランタイムになっており、mnist.onnxから抽出した行列のデータとMNISTの画像データのテキストファイルから推論を行います。デフォルトではmnist_test.txtの10000枚について推論を行います。

出力はこのようになります。

accuracy, time
0.9817000031, 15.93887043
compress_ratio, accuracy, time
0, 0.9817000031, 36.66616821
0.05, 0.9817000031, 34.61940384
0.1, 0.9818000197, 32.64707184
0.15, 0.9818000197, 30.78090286
0.2, 0.9817000031, 29.01914406
..........................

2行目は読み込んだデータを加工せずで隣接行列表現で保持したときの推論の精度と計算時間、4行目以降は読み込んだデータを加工したときの推論の圧縮率(compress_ratio)、精度と計算時間です。圧縮率が0のときは全く加工しないのと同じ、圧縮率が0.3のときは枝の重みの絶対値が小さい30%を削除したときの結果です。圧縮率が1になると全ての枝が削除されます。

データを加工して枝を削除するとき、sonnx.cppでは行列データを非零成分のインデックスとその値の組の配列として保持しています。これは枝を削除したときに保持するデータ量が減らし、高速に計算するためです。

結果

圧縮率を変化させたときの精度と計算時間のグラフです。圧縮率を0.8まで上げても推論の精度が変わっていません。これは学習済みモデル中の8割の枝を削除しても、推論精度が保てるということです。驚きですね!

一方、計算時間はほぼ枝の数に反比例しています。これは予想通りでした。

表にしてみます。推論精度を1%犠牲にするだけで2倍も高速化できています。やったね!

圧縮率 推論精度 計算時間(秒)
0 0.981 15.9
0.8 0.971 7.04

※ 圧縮率0のときの値は重みのデータを隣接行列表現で保持したときのものです。

おまけ

最後にonnxruntimeで元の学習済みモデルを実行したときの時間を計ってみます。

> python3 analyze_mnist_onnx.py
Accuracy rate =  0.9817 , time =  3.285153865814209

え、3秒?? 普通に考えると16秒ぐらいになるはずです。なぜこんなに速いんだ...

理由は色々あると思いますが、僕が思いつくのは以下の2つです。

  • AVXなどのSIMD拡張命令を使っている
    (僕もAVXを使おうとしましたが、scatter命令が僕のPCで使えないので諦めました。)
  • CPUのキャッシュに乗るようなプログラムになっている

関連記事・研究

まとめ

  • ONNXランタイムを自作したらナイーブな実装の二倍の速度で推論ができた
  • onnxruntimeには勝てなかった