formalized-egison -- Egisonの型安全性の証明に向けて

概要

Egisonのパターンマッチの動作を定理証明支援系Coq上で定義し、「Egisonの処理系が項Mを値vに評価する」という命題をCoq上で表現できるようにしました。

動機

Egisonの中核的な機能はユーザが拡張可能なパターンマッチ機構です。例えば以下に示すように、ペアに対してその順序を無視してパターンマッチすることが可能です。

(define $unordered-pair
  (matcher
    {[<pair $ $> [something something]
      {[[$x $y] {[x y] [y x]}]}]
     [$ something
      {[$tgt {tgt}]}]}))

(match-all [1 2] unordered-pair {[<pair $a $b> [a b]]})
; ===> {[1 2] [2 1]}

しかし、その拡張性のためにこのパターンマッチの動作は大変複雑です。Egi, Nishiwaki (2018)はEgisonのパターンマッチの動作をbig-step styleの操作的意味論として次のように定義しています。 Egisonの操作的意味論

型健全性などのEgisonの性質を証明する際はこの操作的意味論を使うことになりますが、規則の種類、数が多いので証明を手で書くと間違えてしまいそうです。そこで今回の記事ではEgi, Nishiwaki (2018)の操作的意味論をCoqに書き直しました。

formalized-egison

formalized-egisonEgi, Nishiwaki (2018)によるEgisonの操作的意味論をCoqに書き直したものです。意味論を定義している部分はEgison.vの以下の部分です。

Inductive eval : (env * tm * tm)-> Prop :=
  | evarin : forall i Gamma t, (Gamma i) = Some t -> eval (Gamma, (tvar i), t)
  | evarout : forall i Gamma, (Gamma i) = None -> eval (Gamma, (tvar i), (tvar i))
  | eint : forall i e, eval (e, (tint i), (tint i))
  | etpl : forall Gamma t1 t2 v1 v2, eval (Gamma, t1, v1) -> eval (Gamma, t2, v2) -> eval (Gamma, ttpl t1 t2, ttpl v1 v2)
  | ecll : forall e ts vs, same_length_list ts vs ->
                      Forall eval (map (fun tpl => let '(t,v) := tpl in (e,t,v)) (zip ts vs)) -> eval (e, (tcll ts), (tcll vs))
  | epair : forall e t1 t2 v1 v2, eval (e, t1, v1) -> eval (e, t2, v2) -> eval (e, (tpair t1 t2), (tpair v1 v2))
  | esm : forall Gamma, eval (Gamma, tsm, tsm)
  (* | emtc : forall Gamma (ts: (list (pptn * tm * (list (dptn * tm))))), eval Gamma ((tmtc ts), (tmtc vs)) *)
  | etplmtc : forall Gamma t1 t2 v1 v2, eval (Gamma, t1, v1) -> eval (Gamma, t2, v2) -> eval (Gamma, ttplmtc t1 t2, ttplmtc v1 v2)
  | etmal : forall Gamma M N p L v_v v m_m m_e Delta_v, same_length_list Delta_v v_v -> eval (Gamma, M,v) -> evalmtc Gamma N [(m_m, m_e)] -> evalms3 [[([(p,m_m,m_e,v)], Gamma, empty)]] Delta_v ->
                                                  Forall eval (map (fun t => let '(d,v) := t in (Gamma @@ d, L, v)) (zip Delta_v v_v)) ->
                                                  eval (Gamma, (tmal M N (p, L)), tcll v_v)

  with evalmtc : env -> tm -> list (tm * env) -> Prop :=
  | emtcsm : forall Gamma, evalmtc Gamma tsm [(tsm, Gamma)]
  | emtcmtc : forall Gamma l, evalmtc Gamma (tmtc l) [((tmtc l), Gamma)]
  | emtctpl : forall Gamma m1 m2 n1 n2, eval (Gamma, (ttplmtc m1 m2), (ttplmtc n1 n2)) -> evalmtc Gamma (ttplmtc m1 m2) [(n1, Gamma); (n2, Gamma)]

  with evaldp : dptn -> tm -> option env -> Prop :=
  | edpvar : forall z v, value v -> evaldp (dpvar z) v (Some (z |-> v))
  | edppair : forall p1 p2 v1 v2 g1 g2,
      value v1 -> value v2 -> evaldp p1 v1 (Some g1) -> evaldp p2 v2 (Some g2) ->
      evaldp (dppair p1 p2) (tpair v1 v2) (Some (g1 @@ g2))
  | edpfail : forall t p1 p2, not (is_tpair t) -> evaldp (dppair p1 p2) t None

  with evalpp : pptn -> env -> ptn -> option ((list ptn) * env) -> Prop :=
  | eppdol : forall g p, evalpp ppdol g p (Some ([p], empty))
  | eppvar : forall i g m v, eval (g, m, v) -> evalpp (ppvar i) g (pval m) (Some ([], (i |-> v)))
  | epppair : forall pp1 pp2 p1 p2 g pv1 pv2 g1 g2,
                evalpp pp1 g p1 (Some (pv1,g1)) -> evalpp pp2 g p2 (Some (pv2,g2)) ->
                evalpp (pppair pp1 pp2) g (ppair p1 p2) (Some ((pv1 ++ pv2), (g1 @@ g2)))
  | eppvarfail : forall y g p, not (is_pval p) -> evalpp (ppvar y) g p None
  | epppairfail : forall pp1 pp2 p g, not (is_ppair p) -> evalpp (pppair pp1 pp2) g p None

  with evalms1 : ((list ms) * option env * option (list ms)) -> Prop :=
  | ems1nil : evalms1 ([], None, None)
  | ems1anil : forall sv g d, evalms1 ((([],g,d)::sv), (Some d), (Some sv))
  | ems1 : forall p m mg v av g d sv avv d1,
        evalma (g @@ d) (p,m,mg,v) avv d1 ->
        evalms1 ((((p,m,mg,v)::av, g, d)::sv), None, (Some ((map (fun ai => (ai ++ av, g, d @@ d1)) avv) ++ sv)))

  with evalms2 : (list (list ms)) -> (list env) -> (list (list ms)) -> Prop :=
  | ems2 : forall svv gvv svv1 gvv1 svv2,
      same_length_list3 svv gvv svv1 ->
      Forall evalms1 (zip3 svv gvv svv1) ->
      (filtersome gvv) = gvv1 ->
      (filtersome svv1) = svv2 ->
      evalms2 svv gvv1 svv2

  with evalms3 : (list (list ms)) -> (list env) -> Prop :=
  | ems3nil : evalms3 [[]] []
  | ems3 : forall svv gv svv1 dv gdv, evalms2 svv gv svv1 -> evalms3 svv1 dv -> gdv = gv ++ dv ->
                             evalms3 svv gdv

  with evalma : env -> ma -> list (list ma) -> env -> Prop :=
  | emasome : forall x g v d, evalma g (pvar x, tsm, d, v) [[]] (x |-> v)
  | emappfail : forall p g pp m sv pv d v avv g1,
      evalpp pp g p None -> evalma g (p,(tmtc pv),d,v) avv g1 ->
      evalma g (p,tmtc ((pp,m,sv)::pv),d,v) avv g1
  | emadpfail : forall p g pp m dp n sv pv d v pv1 d1 avv g1,
      evalpp pp g p (Some (pv1, d1)) -> evaldp dp v None ->
      evalma g (p, tmtc ((pp,m,sv)::pv),d,v) avv g1 ->
      evalma g (p, tmtc ((pp,m,(dp,n)::sv)::pv),d,v) avv g1
  | ema : forall p Gamma pp M dp N sigma_v Delta v phi1_v p1_v Delta1 Delta2 v1_vv m1_v,
      evalpp pp Gamma p (Some (p1_v, Delta1)) ->
      evaldp dp v (Some Delta2) ->
      eval (Delta @@ Delta1 @@ Delta2, N, tcll v1_vv) ->
      evalmtc Gamma M m1_v ->
      evalma Gamma (p, tmtc ((pp,M,(dp,N)::sigma_v)::phi1_v), Delta, v)
             ((map (fun tpl => match tpl with
                              | (ttpl v11 v12) => map (fun t => let '(v1, (m1, Gamma1), p1) := t in (p1,m1,Gamma1,v1)) (zip3 [v11;v12] m1_v p1_v)
                              | v11 => map (fun t => let '(v1, (m1, Gamma1), p1) := t in (p1,m1,Gamma1,v1)) (zip3 [v11] m1_v p1_v)
                              end
                   ) v1_vv)) empty.

最も重要なのはeval (Gamma, M, v)で「環境Gammaのもとで項Mが値vに評価される」という命題です。例えば、冒頭のプログラム中の(match-all [1 2] unordered-pair {[<pair $a $b> [a b]]}){[1 2] [2 1]}に評価される、というのは次のようなCoqの命題として表すことができます。

Definition unordered_pair: tm :=
    (tmtc [(pppair ppdol ppdol, ttplmtc tsm tsm,
            [(dppair (dpvar "x") (dpvar "y"), (tcll [(ttpl (tvar "x") (tvar "y")); ttpl (tvar "y") (tvar "x")]))]);
           (ppdol, tsm,
            [(dpvar "tgt", tcll [tvar "tgt"])])]).

Definition match_all_example: tm :=
    (tmal (tpair (tint 1) (tint 2)) unordered_pair (ppair (pvar "a") (pvar "b"),ttpl (tvar "a") (tvar "b"))).

Theorem unordered_pair_example : eval (empty, match_all_example, tcll [ttpl (tint 1) (tint 2);ttpl (tint 2) (tint 1)]).

また、Egisonの操作的意味論上でこの評価結果が正しいということも証明できます。 つまりunordered_pair_exampleを証明することができます。証明は長くなるので省略しますが、Egison.vの230行目以降にあります。

今後の課題

動機の項で述べたようにformalized-egisonの最終的な目標はEgisonの型安全性の証明です。Coqで表現すると

Theorem type_soundness:
    forall Gamma M T, is_typed Gamma M T => exists v, eval Gamma M v /\ is_typed Gamma v T.

となります。ただしis_typed Gamma M Tは環境Gammaの下で項Mに型Tがつくと読みます。 今後はtyped-egisonの型付け規則をCoqで書き直してis_typedを定義した上でtype_soundnessを証明する予定です。

MNISTを可能な限り高速に分類する

概要

  • MNISTの分類をする学習済みモデルを軽量化し、
  • 更にSIMD命令を使った高速化を行うことで、
  • onnxruntimeより高速なMNISTの分類が可能になりました。(シングルスレッドで)

前回までのあらすじ

この記事は前回の続きです。 前回はMNISTを分類する学習済みニューラルネットワークから不要な枝を削除し、軽量化した学習済みモデルを走らせる専用のランタイムsonnxを作ってMNIST分類の高速化を試みました。 今回はSIMD命令とマルチスレッド化による最適化でMNIST分類速度の限界に挑みます。

今回の目標タイムは既存のONNXランタイムonnxruntimeです。 この記事の各実行時間の計測は10回行い、その平均と分散を求めました。 onnxruntimeと最適化無しのsonnxの実行時間は、OS: Ubuntu19.04、CPU: Core i7 8th Gen、 メモリ: 16GiB、GPU無しの環境下で以下のようになりました。1

手法 実行時間
onnxruntime(シングルスレッド) 1.259秒 (標準偏差 0.1148秒)
onnxruntime(マルチスレッド) 0.505秒 (標準偏差 0.04249秒)
sonnx(最適化無し) 6.968秒 (標準偏差 0.08912秒)

機械学習の推論過程、特にMNISTを分類する場合、においてボトルネックになるのは重み行列を入力ベクトルに乗算する処理です。 前回は重み行列の中で絶対値の小さい要素を無視することで計算の効率化を図り、行列の要素の80%を無視して約5倍の高速化に成功しました。 しかし、80%を削減してもなお乗算はボトルネックでした。

具体的にはsonnx.cppのこの部分がボトルネックになります。

for(int i=0;i<n;i++){
        ret[B_row[i]] += B_scale[i] * x[B_column[i]];
}

SIMDによる高速化

まず、sonnxをSIMDで高速化してみます。 Core i7 8th GenではAVX2命令セットが使えるので256bitの演算を一度に行うことができ、 今回は32bit浮動小数点数で計算しているので最大8倍の高速化が見込めます。

しかし、元のコードはメモリへの間接参照を2つ(ret[B_row[i]]x[B_column[i]])含んでおりそのままではSIMD化するのが困難です。 まずret[B_row[i]]への書き込みは同じB_row[i]の値を持つものをまとめて計算し、メモリへの書き込みを一度に行います。 x[B_column[i]]からの読み出しはAVX2のgather命令を使ってSIMD化します。 完成したコードがこれです。

float r = 0;
__m256 vr = _mm256_setzero_ps();
for(cur;cur<n1;cur+=8){
    __m256i vc = _mm256_loadu_si256((__m256i*)(&B_column_p[cur]));
    __m256 vx = _mm256_i32gather_ps(x_p, vc, 4);
    __m256 vs = _mm256_load_ps(&B_scale_p[cur]);
    vr = _mm256_fmadd_ps(vs, vx, vr);
}
for(cur;cur<n;cur++){
    r += B_scale_p[cur] * x_p[B_column_p[cur]];
}
__attribute__((aligned(32))) float t[8] = {0};
_mm256_store_ps(t, vr);
ret[B_row_p[cur-1]] += r + t[0] + t[1] + t[2] + t[3] + t[4] + t[5] + t[6] + t[7];

実行時間はこんな感じになりました。

手法 実行時間
sonnx(SIMD) 1.121秒(標準偏差 0.02271秒)
onnxruntime(シングルスレッド) 1.259秒(標準偏差 0.1148秒)

ウェルチのt検定を用いて検定すると2有意水準5%で帰無仮説が棄却され、確かにonnxruntimeより高速に推論できています。 onnxruntime(シングルスレッド)と大きな差がつかないのはgather命令が遅いからでしょうか。

マルチスレッドによる高速化

(SIMDなしの)マルチスレッド化はどうでしょうか。 オリジナルのコードはメモリへの書き込みが一箇所(ret[B_row[i]]への書き込み)しかないので、この部分だけ複数スレッドで同時に書き込まないようにすればデータレースは起こりません。 スレッド数は手元で最適なものを探索した結果4にしています。 ソースコードここにあります。

void CompressedGemm::calc_partially(const int index, const vector<float> &x){
    int m = B_nrows_threads[index].size();
    int cur = 0;
    for(int i=0;i<m;i++){
        int n = B_nrows_threads[index][i];
        float r = 0;
        for(int j=cur;j<cur+n;j++){
            r += B_scale_threads[index][j] * x[B_column_threads[index][j]];
        }
        ret[B_row_threads[index][cur]] += r;
        cur += n;
    }
}

vector<float> CompressedGemm::calc(const vector<float> &x){
    ret = C;
    vector<thread> ths;
    
    for(int i=0;i<n_thread;i++){
        ths.push_back(thread(&CompressedGemm::calc_partially, this, i, x));
    }
    for(int i=0;i<n_thread;i++){
        ths[i].join();
    }
    return ret;
}

結果はこんな感じです。

手法 実行時間
sonnx(マルチスレッド) 1.995秒(標準偏差 0.03405秒)
onnxruntime(マルチスレッド) 0.5048秒(標準偏差 0.04249秒)

完敗です。

SIMDとマルチスレッドの併用

では最後にSIMDとマルチスレッディングを併用してみます。 ソースコードここにあります。

手法 実行時間
sonnx(SIMD+マルチスレッド) 1.358秒(標準偏差 0.05762秒)
onnxruntime(マルチスレッド) 0.5048秒(標準偏差 0.04249秒)

あまり効果がありませんね...複数コアからの結果をまとめるのに時間がかかっているのかもしれません。

まとめ

影響力の小さい要素を削除して学習済みのモデルを軽量化した上でSIMDとマルチスレッド化を用いて推論の高速化を試みました。 マルチスレッド環境ではonnxruntimeに勝てませんでしたが、シングルスレッドではonnxruntimeより高速に推論できました。

参考


  1. 前回からOSを入れ替えたので数値が違います。

  2. ちょっとここ正しく議論できているか自信がない

WebAssemblyで自作言語用のGCを書く

前回の続きです。

概要

前回の記事ではWebAssemblyを出力するMLのサブセットコンパイラを作りました。しかし、WebAssemblyにはGabage Collection (GC) が未だに実装されていないため(2019/7/8時点)、メモリ管理は全て自分で行う必要があります。前回は超適当mallocで間に合わせていたのですが、今回はmalloc, freeを含むGCをWebAssembly(手書き)で作りました。

GCを含むMLコンパイラソースコードここで、GC本体はここです。

WebAssemblyの基礎知識

WebAssemblyはスタックマシンで実行される言語であり、各命令は引数をスタックのトップから取り出し結果をスタックのトップに積みます。トップ以外のスタックの中身を自由に見ることはできません1。また、スタックの他にもリニアメモリも持ち、i32.store, i32.loadといった命令でアクセスすることができます。リニアメモリは以下のような特徴を持ちます。

  • アドレスは0から始まる
  • アドレスは32bit
  • 整数又は浮動小数点数をstore/loadできる

GCの実装

GCを書く前にまずmallocとfreeが必要です。今回はfree listでブロックを管理する簡単なものを書きました。各ブロックは下図のような構造になっていて、次のブロックのアドレス、ブロックサイズ、フラグ、確保したメモリを保持します。フラグにはfree listの終端かどうかの情報とブロックが使用中かどうかの情報が入っています。それぞれ4byteずつ使うのでmallocのためのヘッダだけで12byteも使います。贅沢ですね。

+-------------------------------------------------------+
| next block address | block size | flags  | contents   |
| 4bytes             | 4bytes     | 4bytes | 4*n bytes  |
+-------------------------------------------------------+

mallocはfree listを先頭から走査して要求されたサイズ以上の空きブロックがあれば確保、最後まで走査して該当するブロックがなければ末尾に新たにブロックを作ります。freeは指定されたブロックの状態を空きブロックに変えるだけです。空きブロック同士を結合する処理はめんどくさいのでやっていません。

このmallocとfreeを使ってGCを作ります。『ガベージコレクション』によれば、GCには次の3種類があります。

  • マーク・アンド・スイープ
  • コピーGC
  • 参照カウント

このうち、マーク・アンド・スイープとコピーGCは生きているオブジェクトを探索するのにルートセットが必要です。ルートセットは一般にレジスタ、スタック、グローバル変数から構成されます。ところが、WebAssemblyではスタックの中身を自由に見ることができないためルートセットが必要なGCは採用できません2。このため今回は参照カウント方式でGCを実装しました。

gc_malloc を呼び出すとmallocが指定されたサイズ+GCのヘッダサイズ分のメモリをリニアメモリ上に確保し、GCの管理情報を書き込んだ上でそのアドレスを返します。GCの管理領域は下図のようになっていて、確保したメモリのサイズ、参照カウントの値、フラグ、オブジェクトのためのメモリ領域から構成されています。フラグには参照カウント時に行う深さ優先探索用のビットと格納しているオブジェクトが値(整数値又は浮動小数点数)か否かの情報が含まれます。

+----------------------------------------------------+
| memory size | reference count | flags  | contents  |
| 4bytes      | 4bytes          | 4bytes | 4*n bytes |
+----------------------------------------------------+

これがmallocのcontentsの中に入っているので実際にはこのようになります。

+---------------+------------+--------+--------------------------------------------------------+
| next malloc   | malloc     | malloc | malloc contents                                        |
| block address | block size | flags  | +----------------------------------------------------+ |
| 4bytes        | 4bytes     | 4bytes | | memory size | reference count | flags  | contents  | |
|               |            |        | | 4bytes      | 4bytes          | 4bytes | 4*n bytes | |
|               |            |        | +----------------------------------------------------+ |
+---------------+------------+--------+--------------------------------------------------------+  

参照カウントはコンパイラが変数のスコープに応じてオブジェクトの参照カウントを増減させるコード(gc_increase_rc, gc_decrease_rc)を挿入することで操作します。基本的にオブジェクトの作成時に参照カウントが1増加し、そのオブジェクトを指す変数のスコープが終了したところで参照カウントが1減らします。参照カウントが0になるとそのオブジェクトは解放され、またそのオブジェクトが保持していたアドレスが指すオブジェクトの参照カウントを1減らします。この操作は再帰的に行われます。ただし、解放されたオブジェクトが値だった場合はこの再帰的な操作は行われません。

gc_increase_rc, gc_decrease_rcはこんな感じで挿入される

(i32.const 1)
(i32.store)))
(; let fun_fib_0 end ;)
(call $gc_increase_rc)
(get_local $val_b_6)
(call $gc_decrease_rc)
(drop)
(get_local $val_a_5)
(call $gc_decrease_rc)
(drop)

基本的に参照カウントの操作は変数のスコープと連動させればよいのですが、関数終了時には特殊な処理が必要になります。関数内のローカル変数が指すオブジェクトは関数の終端で解放されますが、関数の戻り値が開放されるオブジェクトを含む場合、すでに解放されたオブジェクトが関数の戻り値に含まれることになります。このようなバグを防ぐためには、関数内のローカル変数を開放する処理を行う前に戻り値オブジェクトの参照カウントをインクリメントしておく必要があります。

その他にも配列の要素を書き換えるときは前のオブジェクトの参照カウントを減らしてから、新しく代入するオブジェクトの参照カウントをインクリメントする必要があるなど、細かいところでいろいろハマりました。

結果

フィボナッチ数列を第10項まで計算するプログラムでプログラム終了時のメモリ使用量をGC有りの場合と無しの場合で結果を比較してみると、以下のようになりました。使用メモリの4分の3ぐらいを解放できているのでそれなりに(僕が)満足できる結果です。

- メモリ使用量
GCなし 22868byte
GCあり 4660byte
let rec fib n = 
  if n < 3 
    then 1
    else 
        let a = fib (n - 1) in
        let b = fib (n - 2) in
        a + b in
let rec print x =
  if 10 < x
    then 1
    else
      let a = fib x in
      print_i32 a;
      print (x+1) in
print 0

参考リンク

余談

  • コンパイル時にWebAssemblyで書いたGCをリンクする必要があるのですが、スマートな方法が見つからずテキスト表現のまま無理やりリンクしています。lldを使えばいいのではないか?という情報を頂いたのですが、WebAssemblyのテキスト表現をELF形式に変換する方法がわかりませんでした。
  • WebAssemblyの直書きは辛い...バグると何も言わずに止まっちゃうし。
  • 公式の人は早くGCを組み込んで欲しい。やっぱりスタックを自由に走査できない環境でGCを書くのは無理がある。

  1. 現在のスタックの長さを記録しておけばリニアメモリにスタックの中身をすべて書き出すことはできる。ただしstore/load命令が整数と浮動小数点数で異なるので、スタックに積まれている全ての値について整数なのか浮動小数点数なのかという情報も保存しておく必要がある。

  2. 注1のようにスタックの中身をリニアメモリに書き出せれば、ルートセットが必要な方式も採用できる。

ONNXファイルから不要な枝を削ってMNISTの推論を高速化してみる

この記事の中のソースコードは全てhttps://github.com/akawashiro/sonnxにあります。

概要

背景

機械学習の学習済みモデルを小さなデバイスで動かす、というのが最近流行っているそうです。機械学習では、学習には大きな計算コストがかかりますが、推論はそれほど大きな計算コストがかかりません。このため、学習だけを別のコンピュータで行っておいて、実際の推論は小さなデバイスで行うということが可能です。

ただし、推論だけでもそれなりに計算資源が必要です。そこで、学習済みのモデルの高速化が重要になります。Raspberry Piに搭載されているGPUを使うIdeinとか有名です。

僕も学習済みモデルの推論を高速化できそうな方法を思いついたので実験してみました。

イデア

今回はMNISTを分類する学習済みモデルを高速化します。今回使用するモデルは次の図のようなものです。画像は28*28(=784)pxなので入力は784個、出力は各数字の確率なので10個あり、中間層が2つ挟まっています。各層間は全結合しており、活性化関数としてReluを使います。

このモデルを教師データを使って学習すると、枝の重みが変わってこんな感じになります。

僕のアイデアは学習後のネットワークから重みの小さい枝を取り去ってもちゃんと動くんじゃないか、というものです。重みの小さい枝を取り去るとこんな感じになります。

枝の本数が少なくなれば、推論を高速化できそうな気がします。

イデアの裏付け

実際に学習後のモデルで赤丸で囲った部分の重みの分布を確認します。

分布はこのようになっています。

重みが0の部分が非常に多いです。まず、学習済みモデルから重みが0の枝を削除しても推論結果に影響しないはずです。また、グラフが左右対称になっているので、絶対値の小さい順に削除していけば、各パーセプトロンへの入力はそれほど変化しない気がします。

手法

やることは非常に単純です。ニューラルネットワーク中の各層間枝の重みの統計を取り、重み上位何%かを残して残りを削除します。つまり、何らかの方法で学習済みのモデルから枝の重みのデータを取り出し、枝をカットし、さらに加工後のモデルのデータを使って推論できるようにします。

幸い今はONNXという良いものがあります。ONNXとは学習済みモデルのデータを出力する形式の一つで、多くのフレームワークが対応しています。

今回はChainerで書いたモデルからONNXデータを出力し、そのデータを加工することにしました。また加工後のデータはC++で書いた俺俺ONNXランタイムに実行してもらうことにしました。

纏めると、
1. Chainerでニューラルネットワークを書いて学習する
2. 学習済みのニューラルネットワークからONNXデータを出力する
3. ONNXデータを俺俺ONNXランタイムに読み込んで加工、実行する

となります。一つづつ何をやるのかを説明します。

1. Chainerでニューラルネットワークを書いて学習する

2. 学習済みのニューラルネットワークからONNXデータを出力する

1と2は簡単です。onnx-chainerを使えばすぐにできます。

python3 learn_mnist.py

mnist.onnxというファイルができます。

3. ONNXデータを俺俺ONNXランタイムに読み込んで加工、実行する

ここが大変でした。ディープラーニングフレームワークからのONNXモデルの出力は多くの人が試しているのですが、出力したONNXモデルをチューニングしようとする人はほとんどいないようです。

3.1 ONNXデータを解析する

とりあえずnetronというONNXの可視化ツールでmnist.onnxを可視化してみました。

GemmはGeneral matrix multiplyの略です。各Gemmノードは行列BとベクトルCを持ち、ベクトルxを入力としてBx+Cを出力します。Relu活性化関数です。

Reluはmax(0,x)で定義されている関数のでONNXから抽出する必要は無く、各Gemmノードの行列BとベクトルCの情報だけをを抽出できれば良いです。

今回は各GemmノードのBCをテキストファイルとして抽出します。

> python3 analyze_mnist_onnx.py

は次のファイルを出力します。

  • *************_matrix.txt
    mnist.onnxの全てのGemmノードのB行列とC行列
  • *************_matrix.png
    各行列の中の重みの分布
  • mnist.onnx.json
    ONNXファイルをJSONに変換したもの
  • mnist_train.txt
  • mnist_test.txt
    MNISTをC++から読み込みやすくするためにテキストファイルに変換したもの

*************_matrix.txtがどのGemmノードに対応するのかはmnist.onnx.jsonを睨むとわかります(←ここ超不親切)。

3.2 抽出したデータを加工、実行する
g++ -O3 -mtune=native -march=native -mfpmath=both sonnx.cpp && ./a.out

で出力した重みのデータを読み込み、推論を実行します。sonnx.cppは簡単なONNXランタイムになっており、mnist.onnxから抽出した行列のデータとMNISTの画像データのテキストファイルから推論を行います。デフォルトではmnist_test.txtの10000枚について推論を行います。

出力はこのようになります。

accuracy, time
0.9817000031, 15.93887043
compress_ratio, accuracy, time
0, 0.9817000031, 36.66616821
0.05, 0.9817000031, 34.61940384
0.1, 0.9818000197, 32.64707184
0.15, 0.9818000197, 30.78090286
0.2, 0.9817000031, 29.01914406
..........................

2行目は読み込んだデータを加工せずで隣接行列表現で保持したときの推論の精度と計算時間、4行目以降は読み込んだデータを加工したときの推論の圧縮率(compress_ratio)、精度と計算時間です。圧縮率が0のときは全く加工しないのと同じ、圧縮率が0.3のときは枝の重みの絶対値が小さい30%を削除したときの結果です。圧縮率が1になると全ての枝が削除されます。

データを加工して枝を削除するとき、sonnx.cppでは行列データを非零成分のインデックスとその値の組の配列として保持しています。これは枝を削除したときに保持するデータ量が減らし、高速に計算するためです。

結果

圧縮率を変化させたときの精度と計算時間のグラフです。圧縮率を0.8まで上げても推論の精度が変わっていません。これは学習済みモデル中の8割の枝を削除しても、推論精度が保てるということです。驚きですね!

一方、計算時間はほぼ枝の数に反比例しています。これは予想通りでした。

表にしてみます。推論精度を1%犠牲にするだけで2倍も高速化できています。やったね!

圧縮率 推論精度 計算時間(秒)
0 0.981 15.9
0.8 0.971 7.04

※ 圧縮率0のときの値は重みのデータを隣接行列表現で保持したときのものです。

おまけ

最後にonnxruntimeで元の学習済みモデルを実行したときの時間を計ってみます。

> python3 analyze_mnist_onnx.py
Accuracy rate =  0.9817 , time =  3.285153865814209

え、3秒?? 普通に考えると16秒ぐらいになるはずです。なぜこんなに速いんだ...

理由は色々あると思いますが、僕が思いつくのは以下の2つです。

  • AVXなどのSIMD拡張命令を使っている
    (僕もAVXを使おうとしましたが、scatter命令が僕のPCで使えないので諦めました。)
  • CPUのキャッシュに乗るようなプログラムになっている

関連記事・研究

まとめ

  • ONNXランタイムを自作したらナイーブな実装の二倍の速度で推論ができた
  • onnxruntimeには勝てなかった

高階関数を基本的な関数の合成で作った関数でQuickCheckする

高階関数をQuickCheckでテストしてみる

QuickCheckを知っていますか?
QuickCheckと言うのはHaskellのデータ駆動型のテスト用ライブラリで
テストしたい関数を指定するとその引数に合わせて適当なテストデータを生成してくれます。

では高階関数(関数を引数に取る関数)をQuickCheckでテストすると
どんなテストデータ(関数)を生成してくれるのでしょうか。
ちょっと確認してみましょう。

-- QCTest.hs
import Test.QuickCheck
import Test.QuickCheck.Function

prop :: Fun Integer Integer -> Bool
prop f = apply f 10 == apply f 20

main = quickCheck prop
% stack runghc QCTest.hs
*** Failed! Falsifiable (after 2 tests and 12 shrinks):
{20->0, _->1}

...なんだこの関数は?
入力と出力の組が列挙されてるだけやんけ。
僕の思ってる関数と違うんですが。
map (+2)とかtailとかそういうやつを生成してほしいなぁ。

まともなテスト用関数を生成する

QuickCheckのテストデータが気に入らないので、まともな(主観)テストデータを生成してみました。
今回生成するテストデータは[Int] -> [Int]の関数に限定します。

そもそも関数を生成するってどうやればいいんでしょうか?
QuickCheckのように入力と出力の対を列挙するのは筋が悪い気がします。
mapや(+)、tailといった基本的な関数を合成して関数を生成したいのです。
我々がプログラミングするときってそうやりますよね。
入力と出力の組を列挙する人はあんまりいないと思います。

[Int] -> [Int]型の関数を列挙してみます。

tail
reverse
map (* 10)
map (+ 10)

tailやreverseはそもそも[Int] -> [Int] 型です。
mapは(Int -> Int) -> [Int] -> [Int] 型なので
(+ 10) :: Int -> Intを渡すと[Int] -> [Int]型になります。

こんな風にどの関数(例 map)にどの関数(例 (+ 10))を適応すれば
どんな型(例 [Int] -> [Int])になるかを規則として書き出します。
型付け規則っぽく書くとこんな感じです。
Imgur
ちょっとわかりにくいかも知れません。

[Int] -> [Int]型の関数を2つ作ってみました。
Imgur
左はmap (+ 10)のような関数、右はmap (+ 10 (* 5))のような関数を表しています。
一番下の型にある関数の規則を持ってきて、上に必要な型を書いて
さらにその必要な型をもつ関数の規則を持ってきて...という感じです。

とりあえずこの規則を使って関数を生成できるようになりました。
[Int] -> [Int]型の関数を生成した例がこちらです。

% stack exec qcfun-exe
> Test datum are following.
> ["tail","reverse","(map (* 75))","tail","reverse"]
...

今見せた例はこのようなデータ型で表現されています。

data QProg = QMap1 QProg | QMap2 QProg QProg 
             | QTail1 QProg | QTail | QRev | QMult QProg | QAdd QProg | QRand Int deriving (Eq)

instance Show QProg where
  show (QMap1 p) = "(map " ++ show p ++ ")"
  show (QMap2 p1 p2) = "(map " ++ show p1 ++ " " ++ show p2 ++ ")"
  show (QTail1 p) = "(tail " ++ show p ++ ")"
  show (QTail) = "tail"
  show QRev = "reverse"
  show (QMult p) = "(* " ++ show p ++ ")"
  show (QAdd p) = "(* " ++ show p ++ ")"
  show (QRand i) = show i

QProg型がプログラム(関数)を表す型です。
このQProg型に対して適当なShowインスタンスを定義してあげると関数っぽく見えるようになります。

作ったテストデータ用の関数で高階関数をテストする

次に作った関数を使って高階関数をテストしてみましょう。
今回テストする対象のプログラムはこれです。

func f = reverse . f
prop f = (func f [1,2,3]) == [1,2,3]

関数fを引数に取る簡単な高階関数funcの性質をテストします。

しかし作ったテストデータ用の関数はすべてQProg型です。
つまりQProg型をなんとかして[Int] -> [Int]型に変換したうえで
テストしたい高階関数に適用する必要があります。

今回はTemplate Haskellを使ってQProg型を[Int]->[Int]型のHaskellの関数に変換しました。
Template Haskellというのはプログラムの中でプログラムを作るためのツールです。
Lensなんかで使われています。 動作例がこちらです。

% stack exec qcfun-exe
...
Test results are following.
[(False,"tail"),(True,"reverse"),(False,"(map (* 75))"),(False,"tail"),(True,"reverse")]

テストデータは先頭から順に["tail","reverse","(map (* 75))","tail","reverse"]だったので正常に動いているようです。

まとめ

高階関数をまともな(主観)関数でテストできるようにしました。
しかし技術的な制約のためテストデータをコンパイル時に生成するので
テストデータを完全にランダムに生成することはできませんでした。

参考

http://haskell.g.hatena.ne.jp/mr_konn/20111218/1324220725

付録

githubへのリンクです
https://github.com/akawashiro/qcfun
以下のコマンドで実行できます。

git clone https://github.com/akawashiro/qcfun.git
stack build
stack exec qcfun-exe

ソースコードはTemplate Haskell的な事情で2つにわかれています。

WebAssemblyを出力するMinCamlコンパイラを実装しました

概要

WebAssemblyを出力するMinCamlコンパイラml2wasmフルスクラッチで実装しました。
github.com

マンデルブロ集合を計算するこんな↓感じのMinCamlソースコード

f:id:a_kawashiro:20181029180551p:plain:h325,w500
マンデルブロ集合を出力するMinCamlソースコード

こんな↓感じのWebAssemblyに変換されて

f:id:a_kawashiro:20181029183757p:plain:h325,w500
出力されたWebAssembly

実行して、適切にプロットするとこんな↓感じになります。

f:id:a_kawashiro:20181029180607p:plain:h325,w500
作成したコンパイラコンパイルし実行・プロットしたもの

導入

WebAssemblyとは

WebAssemblyとはブラウザで実行可能な低級プログラミング言語であり、仮想的なスタックマシン上で動作します。 また、JavaScriptに比べて構文解析が容易であり高速に動作します。 WebAssemblyは最近のほとんどのブラウザで動かすことができます。*1

;; WebAssembly のテキスト表現の例
(module
  (func (export "add") (result i32)
    (i32.add
      (i32.const 1)
      (i32.const 3))))

MinCamlとは

MinCamlとは東北大学の住井先生*2が設計したプログラミング言語MLのサブセットです。 言語仕様が小さく、処理系を制作するのが比較的容易なのが特徴です。 住井先生は未踏事業としてMinCamlコンパイラも制作されています。 このコンパイラは読みやすくドキュメントもしっかりしているので、コンパイラを作ってみたい人には非常にオススメです。 ml2wasmの開発でもこのコンパイラを大いに参考にしました。

(* MinCaml の例 *)
let rec add = 1 + 3 in add

コンパイラの作成

今回はMinCamlソースコードをWebAssemblyのテキスト表現に変換するコンパイラml2wasmを作成しました。 基本的にml2wasmはオリジナルのMinCamlコンパイラと同じ構造ですが、WebAssemblyがスタックマシンであることを利用していくつかの変換の工程が省略されています。

ml2wasmの実装の詳細

ml2wasmでは次の5段階の工程でMinCamlをWebAssemblyに変換します。

  1. パース
  2. アルファ変換
  3. 型推論
  4. クロージャ変換
  5. コード生成

特に説明のない部分はオリジナルのMinCamlコンパイラと同じです。 今回はHaskellで実装しましたが、もともとの実装言語がOCamlHaskellとそれほど変わらないので、その点は苦労しませんでした。 以下、WebAssembly特有のコンパイラ実装事情をいくつか説明します。

オリジナルのMinCamlコンパイラには3と4の間にK正規化、4と5の間にレジスタ割当という処理がありましたが、ml2wasmではこれら2つの処理を省略しています。 K正規化はWebAssemblyでは命令のオペランドとしてスタックの先頭の値を使い、オペランドレジスタに格納する必要がないため、 レジスタ割当はWebAssemblyにはレジスタが存在しないため、それぞれ省略できました。

逆に、WebAssembly特有の難しい工程としてはコード生成があります。 通常、関数型言語コンパイラではプログラムをクロージャ変換し、関数から自由変数を削除します。 その後コード生成の工程で関数呼び出しは間接ジャンプ(MIPSならjr命令)に変換されます。

この間接ジャンプをWebAssemblyではcall_indirect命令で実現します。

;; call_indirect命令はこんな感じで呼び出される。
...
(i32.load)
(call_indirect (param f32) (param f32) (param i32) (result i32)))))
...
;; (param f32) (param f32) (param i32) (result i32) ← 型が必要

ここで問題になるのがcall_indirect命令が呼び出す関数の型をオペランドとして要求する点です。 これはクロージャ変換後の中間言語に型情報を残しておく必要があることを意味します。 コンパイラを書き始めたときにはこのことに気づかなかったので、あとから型情報を追加するのに苦労しました。

他には公式のリファレンスがどこにあるのかイマイチわからない問題もあります。 結局これを使ったのですが、 URLに公式感がなく最後までこのリファレンスが公式のものなのかはわからずじまいでした。

結果

マンデルブロ集合を出力するプログラムをコンパイルできるようになりました! やったね!

マンデルブロ集合を出力するMinCamlソースコード
mandelbrot.ml · GitHub
コンパイル結果
mandelbrot.wast · GitHub

ちなみにマンデルブロ集合っていうのはこんな↓やつです。

f:id:a_kawashiro:20181031155626j:plain:h325,w500
本当はこんな感じになって欲しかった

今回作ったコンパイラコンパイルし実行した結果(を適切にプロットしたもの)は↓

f:id:a_kawashiro:20181029180607p:plain:h325,w500
作成したコンパイラコンパイルし実行・プロットしたもの

まあ大体あってますね。点の数が以上に少ないのは大量の点をプロットしようとするとすぐにメモリ不足に陥るからです。

今後の課題

ml2wasmで出力したWebAssemblyコードはすぐにメモリ不足に陥ります。 理由は単純でGCがないからです。 WebAssemblyでGCを実装する試みとしては κeenさんのWebAssemblyでGC | κeenのHappy Hacκing Blog がありますが、今回はめんどくさいのでメモリは確保しっぱなしです。

そのうちWebAssemblyにデフォルトでGCが入るとかいう噂も聞くので、そのときにGCを入れればいいやと思っています。

参考リンク

ネットワークデバイスドライバを一からビルドしてインストールした

概要

新し目のコンピュータにDebianを入れたら、ネットワークデバイスドライバが入ってなくて結構大変だった。

経緯

最近、自分の使っていたコンピュータ(6年前くらいのデスクトップ)に限界を感じ始めたので、日本橋で新しいものを買ってくることにした。購入に当たっては、換装が難しいCPUをCorei7 8700kに決定し、フィーリングで選ぶことにした。日本橋を端から端まで歩き回って、結局LM-iH700XD1-EX4を買った。

f:id:a_kawashiro:20180928082938g:plain

Windows10をDebian streachで上書きして起動すると、ネットワークドライバが入ってないことがわかった。マザーボードが新しすぎるのでドライバがDebianに取り込まれていないのだろう。こういう場合は自前でドライバをビルドする必要がある。

NIC(Network Interface Card)の型番を調べる

まず、NICの型番を調べる。商品のスペックを確認したが、「LAN:10/100/1000BASE-T LAN(オンボード)」としか書いていない。マザーボードの方には「マザーボード:Intel B360 Micro ATX LGA1151」と書いてあった。オンボードNICなのだから、たぶん「Intel B360」のほうが関係しているのだろうと考えられる。ここから「Intel B360 network driver」などで検索して一時間ほど頑張ったが、なんの成果も得られなかった。

次はLinux側からNICを調べてみる。

% lspci
(省略)
00:1f.6 Ethernet controller: Intel Corporation Device 15bc (rev 10)

なるほど、「Intel Corporation Device 15bc linux driver」で検索すれば良さそうである。検索すると、Linux Kernel Driver DataBaseがトップに出てきた。「15bc 」でページ検索すると

vendor: 8086 ("Intel Corporation"), device: 15bc ("Ethernet Connection (7) I219-V")

が出てくる。「I219-V linux driver」で検索してみると「e1000e」の最新のドライバを入れればいいらしい?

そこで、もうどこで見たのか忘れたけど Intel Ethernet Drivers and Utilities の最新版をインストールすればいいらしい。(検索しすぎて記憶がないorz)

makeを入れる

というわけでe1000e-3.4.2.1.tar.gzをビルドする。 ネットワークに繋がらないので別のパソコンでファイルをダウンロードしてUSBメモリでコピーした。

展開して、make installすると...makeコマンドが入ってないと言われた。 Debianのインストール元のlive USBをマウント後、apt install build-essentialで入れた。 最初から入れといてほしい...

デバイスドライバをビルドする

e1000e-3.4.2.1.tar.gzを展開して、make installするとカーネルのヘッダファイルが足りないと言われた。

% uname -a
Linux gley 4.9.0-8-amd64 #1 SMP Debian 4.9.110-3+deb9u4 (2018-08-21) x86_64 GNU/Linux

4.9.0-8-amd64のヘッダファイルが必要なようだ。別のパソコンで該当のdebファイルを落としてきてdpkg -iで入れる。ちなみにcommonamd64の両方が必要だった。

その後、再度make install。 再起動するとネットワークに繋がっていた。

まとめ

新し目のパソコンでLinuxを動かすのは大変。 ただドライバをビルドしてインストールする作業は初めてで勉強になった。 もうやりたくない。

ちなみにオーディオデバイスドライバが入っていないので音は出ない。 そのうち頑張る。