setjmp, longjmpで使われるjmp_bufには何がどのように保存されているのか

setjmp, longjmpはC言語で実行コンテキストを保存し、保存したコンテキストに復帰するために使われる関数(orマクロ)である。g++の例外実装の一つはこのsetjmp, longjmpを用いて実現されている。

この記事では 実行コンテキスト とは何を指すのか? また、その保存方法の実装について調べる。

jmp_buf

setjmpの実装はglibcsetjmp/setjmp.h にある。

typedef struct __jmp_buf_tag jmp_buf[1];

/* Store the calling environment in ENV, also saving the signal mask.
   Return 0.  */
extern int setjmp (jmp_buf __env) __THROWNL;

このjmp_bufsetjmp/bits/types/struct___jmp_buf_tag.h で定義されている。

/* Calling environment, plus possibly a saved signal mask.  */
struct __jmp_buf_tag
  {
    /* NOTE: The machine-dependent definitions of `__sigsetjmp'
       assume that a `jmp_buf' begins with a `__jmp_buf' and that
       `__mask_was_saved' follows it.  Do not move these members
       or add others before it.  */
    __jmp_buf __jmpbuf;     /* Calling environment.  */
    int __mask_was_saved;  /* Saved the signal mask?  */
    __sigset_t __saved_mask;    /* Saved signal mask.  */
  };

x86-64の場合、__jmp_bufsysdeps/x86/bits/setjmp.hで定義されている。__jmp_bufは8個のint型整数からなり、更に64bitかどうかで場合分けがなされていることがわかる。

# if __WORDSIZE == 64
typedef long int __jmp_buf[8];
# elif defined  __x86_64__
__extension__ typedef long long int __jmp_buf[8];
# else
typedef int __jmp_buf[6];
# endif

setjmp

setjmp自体の実装はsysdeps/x86_64/setjmp.Sにある。Linux x86-64では第一引数を%rdiで渡すので、第一引数として渡された構造体にレジスタの値を保存していることがわかる。保存されているのは、%rbx, %r12, %r13, %r14, %r15, %rsp, PCである。しかし、%rsp, %rbp, PCはPTR_MANGLEで処理されてから保存されている。PTR_MANGLEとは何だろうか? また、なぜ必要なのだろう?

ENTRY (__sigsetjmp)
    /* Save registers.  */
    movq %rbx, (JB_RBX*8)(%rdi)
#ifdef PTR_MANGLE
# ifdef __ILP32__
    /* Save the high bits of %rbp first, since PTR_MANGLE will
       only handle the low bits but we cannot presume %rbp is
       being used as a pointer and truncate it.  Here we write all
       of %rbp, but the low bits will be overwritten below.  */
    movq %rbp, (JB_RBP*8)(%rdi)
# endif
    mov %RBP_LP, %RAX_LP
    PTR_MANGLE (%RAX_LP)
    mov %RAX_LP, (JB_RBP*8)(%rdi)
#else
    movq %rbp, (JB_RBP*8)(%rdi)
#endif
    movq %r12, (JB_R12*8)(%rdi)
    movq %r13, (JB_R13*8)(%rdi)
    movq %r14, (JB_R14*8)(%rdi)
    movq %r15, (JB_R15*8)(%rdi)
    lea 8(%rsp), %RDX_LP    /* Save SP as it will be after we return.  */
#ifdef PTR_MANGLE
    PTR_MANGLE (%RDX_LP)
#endif
    movq %rdx, (JB_RSP*8)(%rdi)
    mov (%rsp), %RAX_LP /* Save PC we are returning to now.  */
    LIBC_PROBE (setjmp, 3, LP_SIZE@%RDI_LP, -4@%esi, LP_SIZE@%RAX_LP)
#ifdef PTR_MANGLE
    PTR_MANGLE (%RAX_LP)
#endif
    movq %rax, (JB_PC*8)(%rdi)

#ifdef SHADOW_STACK_POINTER_OFFSET
# if IS_IN (libc) && defined SHARED && defined FEATURE_1_OFFSET
    /* Check if Shadow Stack is enabled.  */
    testl $X86_FEATURE_1_SHSTK, %fs:FEATURE_1_OFFSET
    jz L(skip_ssp)
# else
    xorl %eax, %eax
# endif
    /* Get the current Shadow-Stack-Pointer and save it.  */
    rdsspq %rax
    movq %rax, SHADOW_STACK_POINTER_OFFSET(%rdi)
# if IS_IN (libc) && defined SHARED && defined FEATURE_1_OFFSET
L(skip_ssp):
# endif
#endif
#if IS_IN (rtld)
    /* In ld.so we never save the signal mask.  */
    xorl %eax, %eax
    retq
#else
    /* Make a tail call to __sigjmp_save; it takes the same args.  */
    jmp __sigjmp_save
#endif
END (__sigsetjmp)

PTR_MANGLEは sysdeps/unix/sysv/linux/x86_64/sysdep.h に定義がある。

#  define PTR_MANGLE(reg)    xor %fs:POINTER_GUARD, reg;           \
                rol $2*LP_SIZE+1, reg
#  define PTR_DEMANGLE(reg) ror $2*LP_SIZE+1, reg;                \
                xor %fs:POINTER_GUARD, reg

Linux x86-64では%fsレジスタがThread Local Storageのベースポインタとして使われることを思い出すと、このコードは

  • スレッド固有の値(%fs:POINTER_GUARD)とのxorを取り
  • $2*LP_SIZE+1だけレジスタのビットを回転させる

ことがわかる。PTR_MAGNLEの詳細は Pointer Encryptionにある。

なぜsetjmp, longjmpに際してポインタを暗号化するのだろうか? 最近のglibcではatexit関数やjmp_bufを狙った攻撃は効かない (PTR_MANGLE)Pointer Subterfugeによれば、書き込み可能な領域に生のEIPを保存すると攻撃対象になるからだそうだ。

結論

  • setjmpは%rbx, %r12, %r13, %r14, %r15, %rsp, PCをjmp_bufに保存する
  • このとき、生のEIPを保存するとそのデータ自体が攻撃目標になるため、PTR_MANGLEで暗号化する

検索エンジンを自作する夢

そろそろ学生も終わるので、1年前の失敗と反省を書き留めておこうと思います。私達(@a_kawashiro@gky360)は2018年度の未踏プロジェクトで分野限定型検索エンジンを複数組み合わせた分散型検索エンジンとして採用され、9ヶ月間を検索エンジンの作成に費やしました。この記事ではプロジェクトの紹介と、作成した検索エンジンが満足な性能を出せなかった原因について述べます。

膨大なインターネットの情報の海の中で目的の情報に辿り着くためには、検索エンジンが必要不可欠です。しかし世の中のメジャーな検索エンジンは大企業が提供してくれているものばかりで、ユーザが自分で管理・設定できるものはほとんどありません。それなら一つ作ってみるか、ということで始めたのがこのプロジェクトでした。

検索エンジンは大まかに言って3つの機能で構成されます。

  • クロール: インターネット上からウェブページをダウンロードして情報を蓄積する
  • 検索: 蓄積した情報のなかからユーザのクエリにマッチするものを探し出す
  • 表示: マッチするものを適切に並び替えて表示する

どれも最低限の実装をするならそれほど難しくないので、検索エンジンを作るのはそれほど難しくないように思われます。

しかし、現在のインターネットはあまりに巨大です1千里眼やMondouが作られた時代はせいぜい300万サイトしかありませんでしたが、今は16億サイトを超えていて素直な方法で検索エンジンを作ってもまともな結果が帰ってこないのは明らかです。

Googleのような巨大なデータセンターを使えればよいのですが、お金に余裕がなかったので、インターネット上の個人のサーバをうまく協力させて一つの検索エンジンを構成することにしました。このような検索エンジンは分散検索エンジンと呼ばれ、yacyなどがよく知られています。yacyは完全なpeer to peer型の分散検索エンジンですが、私達は

  1. 数多くのwebページを分野ごとに分類し各分野ごとの検索エンジンを作り、
  2. 分野ごとの検索エンジンを複数つなぎ合わせることでより大きな検索エンジンを作る

という設計を採用することにしました。この設計のメリットは、

  • 各分野ごとの検索エンジンはその分野に特化することで、大きなリソースを使わずに精度の高い検索サービスを提供でき、
  • 更にそれらを複数つなぎ合わせることでより広範囲の検索サービスを段階的に構築できる

の2点です。完全なpeer to peer型の分散検索エンジンに比べて、ある程度の階層構造があることで検索精度の改善が期待できます。

この設計を元に実装したのが、kearchです。kearchには2つの検索エンジンの実装が含まれます。一つは専門検索エンジン(各分野ごとの検索エンジン)、もう一つはメタ検索エンジン(専門検索エンジンをつなぎ合わせるための検索エンジン)です。

ユーザからのクエリはメタ検索エンジンによって適切な専門検索エンジンに割り振られ、ユーザはその専門検索エンジンの結果をメタ検索エンジンを通して受け取ります。例えばこの検索エンジンの構成で、"pip3 install flask"と検索するとPythonの専門検索エンジンサーバにクエリが飛んでPythonに関する検索結果が返ってきます。

f:id:a_kawashiro:20200304225906p:plain
kearchを使った検索エンジンの構成例

2019年1月から3月にかけて、私達はさくらクラウド上で上の図の検索エンジンを実際に運用してみました。確かに専門検索エンジンメタ検索エンジンも動くのですが実用には程遠い検索性能でした。その理由は以下の2つです。

まず、1つ目の「専門検索エンジンの数が足りない」ですが、これはつまり、人間が日常的に検索する分野の範囲が広すぎるということです。ブラウザの履歴を確認してもらうとわかるのですが、1人の人間が1ヶ月で検索する分野は非常に広大です。明らかに4台では足りませんでした。この問題に対する解決策は未だに思いついていません。

2つ目の「専門検索エンジンのクロール性能が低い」は具体的にはwebページの分類性能が低いということです。現在のkearchでは該当分野の単語リストを使ってwebページを分類するのですが、単語リストだけでは十分な性能が出ません。該当分野と全く関係ないwebページがクロールされていることもよくありました。

2つ目の問題にはいくつかの解決策が考えられます。例えば、該当分野の論文を読み込ませて単語リストを拡充したり、信頼できるwebページからのホップ数を考慮することで分類精度を改善できるでしょう。またFessなどのエンタープライズサーチ向けの製品を専門検索エンジンとして使えるように改造するのも解決策の一つだと思います。

最後に、専門検索エンジンについては様々な問題を感じる一方、メタ検索エンジンについてはさほど問題を感じませんでした。「小さな検索エンジンをつなぎ合わせてより大きな検索エンジンを作る」という基本設計は間違っていないと今でも思っています。

formalized-egison -- Egisonの型安全性の証明に向けて

概要

Egisonのパターンマッチの動作を定理証明支援系Coq上で定義し、「Egisonの処理系が項Mを値vに評価する」という命題をCoq上で表現できるようにしました。

動機

Egisonの中核的な機能はユーザが拡張可能なパターンマッチ機構です。例えば以下に示すように、ペアに対してその順序を無視してパターンマッチすることが可能です。

(define $unordered-pair
  (matcher
    {[<pair $ $> [something something]
      {[[$x $y] {[x y] [y x]}]}]
     [$ something
      {[$tgt {tgt}]}]}))

(match-all [1 2] unordered-pair {[<pair $a $b> [a b]]})
; ===> {[1 2] [2 1]}

しかし、その拡張性のためにこのパターンマッチの動作は大変複雑です。Egi, Nishiwaki (2018)はEgisonのパターンマッチの動作をbig-step styleの操作的意味論として次のように定義しています。 Egisonの操作的意味論

型健全性などのEgisonの性質を証明する際はこの操作的意味論を使うことになりますが、規則の種類、数が多いので証明を手で書くと間違えてしまいそうです。そこで今回の記事ではEgi, Nishiwaki (2018)の操作的意味論をCoqに書き直しました。

formalized-egison

formalized-egisonEgi, Nishiwaki (2018)によるEgisonの操作的意味論をCoqに書き直したものです。意味論を定義している部分はEgison.vの以下の部分です。

Inductive eval : (env * tm * tm)-> Prop :=
  | evarin : forall i Gamma t, (Gamma i) = Some t -> eval (Gamma, (tvar i), t)
  | evarout : forall i Gamma, (Gamma i) = None -> eval (Gamma, (tvar i), (tvar i))
  | eint : forall i e, eval (e, (tint i), (tint i))
  | etpl : forall Gamma t1 t2 v1 v2, eval (Gamma, t1, v1) -> eval (Gamma, t2, v2) -> eval (Gamma, ttpl t1 t2, ttpl v1 v2)
  | ecll : forall e ts vs, same_length_list ts vs ->
                      Forall eval (map (fun tpl => let '(t,v) := tpl in (e,t,v)) (zip ts vs)) -> eval (e, (tcll ts), (tcll vs))
  | epair : forall e t1 t2 v1 v2, eval (e, t1, v1) -> eval (e, t2, v2) -> eval (e, (tpair t1 t2), (tpair v1 v2))
  | esm : forall Gamma, eval (Gamma, tsm, tsm)
  (* | emtc : forall Gamma (ts: (list (pptn * tm * (list (dptn * tm))))), eval Gamma ((tmtc ts), (tmtc vs)) *)
  | etplmtc : forall Gamma t1 t2 v1 v2, eval (Gamma, t1, v1) -> eval (Gamma, t2, v2) -> eval (Gamma, ttplmtc t1 t2, ttplmtc v1 v2)
  | etmal : forall Gamma M N p L v_v v m_m m_e Delta_v, same_length_list Delta_v v_v -> eval (Gamma, M,v) -> evalmtc Gamma N [(m_m, m_e)] -> evalms3 [[([(p,m_m,m_e,v)], Gamma, empty)]] Delta_v ->
                                                  Forall eval (map (fun t => let '(d,v) := t in (Gamma @@ d, L, v)) (zip Delta_v v_v)) ->
                                                  eval (Gamma, (tmal M N (p, L)), tcll v_v)

  with evalmtc : env -> tm -> list (tm * env) -> Prop :=
  | emtcsm : forall Gamma, evalmtc Gamma tsm [(tsm, Gamma)]
  | emtcmtc : forall Gamma l, evalmtc Gamma (tmtc l) [((tmtc l), Gamma)]
  | emtctpl : forall Gamma m1 m2 n1 n2, eval (Gamma, (ttplmtc m1 m2), (ttplmtc n1 n2)) -> evalmtc Gamma (ttplmtc m1 m2) [(n1, Gamma); (n2, Gamma)]

  with evaldp : dptn -> tm -> option env -> Prop :=
  | edpvar : forall z v, value v -> evaldp (dpvar z) v (Some (z |-> v))
  | edppair : forall p1 p2 v1 v2 g1 g2,
      value v1 -> value v2 -> evaldp p1 v1 (Some g1) -> evaldp p2 v2 (Some g2) ->
      evaldp (dppair p1 p2) (tpair v1 v2) (Some (g1 @@ g2))
  | edpfail : forall t p1 p2, not (is_tpair t) -> evaldp (dppair p1 p2) t None

  with evalpp : pptn -> env -> ptn -> option ((list ptn) * env) -> Prop :=
  | eppdol : forall g p, evalpp ppdol g p (Some ([p], empty))
  | eppvar : forall i g m v, eval (g, m, v) -> evalpp (ppvar i) g (pval m) (Some ([], (i |-> v)))
  | epppair : forall pp1 pp2 p1 p2 g pv1 pv2 g1 g2,
                evalpp pp1 g p1 (Some (pv1,g1)) -> evalpp pp2 g p2 (Some (pv2,g2)) ->
                evalpp (pppair pp1 pp2) g (ppair p1 p2) (Some ((pv1 ++ pv2), (g1 @@ g2)))
  | eppvarfail : forall y g p, not (is_pval p) -> evalpp (ppvar y) g p None
  | epppairfail : forall pp1 pp2 p g, not (is_ppair p) -> evalpp (pppair pp1 pp2) g p None

  with evalms1 : ((list ms) * option env * option (list ms)) -> Prop :=
  | ems1nil : evalms1 ([], None, None)
  | ems1anil : forall sv g d, evalms1 ((([],g,d)::sv), (Some d), (Some sv))
  | ems1 : forall p m mg v av g d sv avv d1,
        evalma (g @@ d) (p,m,mg,v) avv d1 ->
        evalms1 ((((p,m,mg,v)::av, g, d)::sv), None, (Some ((map (fun ai => (ai ++ av, g, d @@ d1)) avv) ++ sv)))

  with evalms2 : (list (list ms)) -> (list env) -> (list (list ms)) -> Prop :=
  | ems2 : forall svv gvv svv1 gvv1 svv2,
      same_length_list3 svv gvv svv1 ->
      Forall evalms1 (zip3 svv gvv svv1) ->
      (filtersome gvv) = gvv1 ->
      (filtersome svv1) = svv2 ->
      evalms2 svv gvv1 svv2

  with evalms3 : (list (list ms)) -> (list env) -> Prop :=
  | ems3nil : evalms3 [[]] []
  | ems3 : forall svv gv svv1 dv gdv, evalms2 svv gv svv1 -> evalms3 svv1 dv -> gdv = gv ++ dv ->
                             evalms3 svv gdv

  with evalma : env -> ma -> list (list ma) -> env -> Prop :=
  | emasome : forall x g v d, evalma g (pvar x, tsm, d, v) [[]] (x |-> v)
  | emappfail : forall p g pp m sv pv d v avv g1,
      evalpp pp g p None -> evalma g (p,(tmtc pv),d,v) avv g1 ->
      evalma g (p,tmtc ((pp,m,sv)::pv),d,v) avv g1
  | emadpfail : forall p g pp m dp n sv pv d v pv1 d1 avv g1,
      evalpp pp g p (Some (pv1, d1)) -> evaldp dp v None ->
      evalma g (p, tmtc ((pp,m,sv)::pv),d,v) avv g1 ->
      evalma g (p, tmtc ((pp,m,(dp,n)::sv)::pv),d,v) avv g1
  | ema : forall p Gamma pp M dp N sigma_v Delta v phi1_v p1_v Delta1 Delta2 v1_vv m1_v,
      evalpp pp Gamma p (Some (p1_v, Delta1)) ->
      evaldp dp v (Some Delta2) ->
      eval (Delta @@ Delta1 @@ Delta2, N, tcll v1_vv) ->
      evalmtc Gamma M m1_v ->
      evalma Gamma (p, tmtc ((pp,M,(dp,N)::sigma_v)::phi1_v), Delta, v)
             ((map (fun tpl => match tpl with
                              | (ttpl v11 v12) => map (fun t => let '(v1, (m1, Gamma1), p1) := t in (p1,m1,Gamma1,v1)) (zip3 [v11;v12] m1_v p1_v)
                              | v11 => map (fun t => let '(v1, (m1, Gamma1), p1) := t in (p1,m1,Gamma1,v1)) (zip3 [v11] m1_v p1_v)
                              end
                   ) v1_vv)) empty.

最も重要なのはeval (Gamma, M, v)で「環境Gammaのもとで項Mが値vに評価される」という命題です。例えば、冒頭のプログラム中の(match-all [1 2] unordered-pair {[<pair $a $b> [a b]]}){[1 2] [2 1]}に評価される、というのは次のようなCoqの命題として表すことができます。

Definition unordered_pair: tm :=
    (tmtc [(pppair ppdol ppdol, ttplmtc tsm tsm,
            [(dppair (dpvar "x") (dpvar "y"), (tcll [(ttpl (tvar "x") (tvar "y")); ttpl (tvar "y") (tvar "x")]))]);
           (ppdol, tsm,
            [(dpvar "tgt", tcll [tvar "tgt"])])]).

Definition match_all_example: tm :=
    (tmal (tpair (tint 1) (tint 2)) unordered_pair (ppair (pvar "a") (pvar "b"),ttpl (tvar "a") (tvar "b"))).

Theorem unordered_pair_example : eval (empty, match_all_example, tcll [ttpl (tint 1) (tint 2);ttpl (tint 2) (tint 1)]).

また、Egisonの操作的意味論上でこの評価結果が正しいということも証明できます。 つまりunordered_pair_exampleを証明することができます。証明は長くなるので省略しますが、Egison.vの230行目以降にあります。

今後の課題

動機の項で述べたようにformalized-egisonの最終的な目標はEgisonの型安全性の証明です。Coqで表現すると

Theorem type_soundness:
    forall Gamma M T, is_typed Gamma M T => exists v, eval Gamma M v /\ is_typed Gamma v T.

となります。ただしis_typed Gamma M Tは環境Gammaの下で項Mに型Tがつくと読みます。 今後はtyped-egisonの型付け規則をCoqで書き直してis_typedを定義した上でtype_soundnessを証明する予定です。

MNISTを可能な限り高速に分類する

概要

  • MNISTの分類をする学習済みモデルを軽量化し、
  • 更にSIMD命令を使った高速化を行うことで、
  • onnxruntimeより高速なMNISTの分類が可能になりました。(シングルスレッドで)

前回までのあらすじ

この記事は前回の続きです。 前回はMNISTを分類する学習済みニューラルネットワークから不要な枝を削除し、軽量化した学習済みモデルを走らせる専用のランタイムsonnxを作ってMNIST分類の高速化を試みました。 今回はSIMD命令とマルチスレッド化による最適化でMNIST分類速度の限界に挑みます。

今回の目標タイムは既存のONNXランタイムonnxruntimeです。 この記事の各実行時間の計測は10回行い、その平均と分散を求めました。 onnxruntimeと最適化無しのsonnxの実行時間は、OS: Ubuntu19.04、CPU: Core i7 8th Gen、 メモリ: 16GiB、GPU無しの環境下で以下のようになりました。1

手法 実行時間
onnxruntime(シングルスレッド) 1.259秒 (標準偏差 0.1148秒)
onnxruntime(マルチスレッド) 0.505秒 (標準偏差 0.04249秒)
sonnx(最適化無し) 6.968秒 (標準偏差 0.08912秒)

機械学習の推論過程、特にMNISTを分類する場合、においてボトルネックになるのは重み行列を入力ベクトルに乗算する処理です。 前回は重み行列の中で絶対値の小さい要素を無視することで計算の効率化を図り、行列の要素の80%を無視して約5倍の高速化に成功しました。 しかし、80%を削減してもなお乗算はボトルネックでした。

具体的にはsonnx.cppのこの部分がボトルネックになります。

for(int i=0;i<n;i++){
        ret[B_row[i]] += B_scale[i] * x[B_column[i]];
}

SIMDによる高速化

まず、sonnxをSIMDで高速化してみます。 Core i7 8th GenではAVX2命令セットが使えるので256bitの演算を一度に行うことができ、 今回は32bit浮動小数点数で計算しているので最大8倍の高速化が見込めます。

しかし、元のコードはメモリへの間接参照を2つ(ret[B_row[i]]x[B_column[i]])含んでおりそのままではSIMD化するのが困難です。 まずret[B_row[i]]への書き込みは同じB_row[i]の値を持つものをまとめて計算し、メモリへの書き込みを一度に行います。 x[B_column[i]]からの読み出しはAVX2のgather命令を使ってSIMD化します。 完成したコードがこれです。

float r = 0;
__m256 vr = _mm256_setzero_ps();
for(cur;cur<n1;cur+=8){
    __m256i vc = _mm256_loadu_si256((__m256i*)(&B_column_p[cur]));
    __m256 vx = _mm256_i32gather_ps(x_p, vc, 4);
    __m256 vs = _mm256_load_ps(&B_scale_p[cur]);
    vr = _mm256_fmadd_ps(vs, vx, vr);
}
for(cur;cur<n;cur++){
    r += B_scale_p[cur] * x_p[B_column_p[cur]];
}
__attribute__((aligned(32))) float t[8] = {0};
_mm256_store_ps(t, vr);
ret[B_row_p[cur-1]] += r + t[0] + t[1] + t[2] + t[3] + t[4] + t[5] + t[6] + t[7];

実行時間はこんな感じになりました。

手法 実行時間
sonnx(SIMD) 1.121秒(標準偏差 0.02271秒)
onnxruntime(シングルスレッド) 1.259秒(標準偏差 0.1148秒)

ウェルチのt検定を用いて検定すると2有意水準5%で帰無仮説が棄却され、確かにonnxruntimeより高速に推論できています。 onnxruntime(シングルスレッド)と大きな差がつかないのはgather命令が遅いからでしょうか。

マルチスレッドによる高速化

(SIMDなしの)マルチスレッド化はどうでしょうか。 オリジナルのコードはメモリへの書き込みが一箇所(ret[B_row[i]]への書き込み)しかないので、この部分だけ複数スレッドで同時に書き込まないようにすればデータレースは起こりません。 スレッド数は手元で最適なものを探索した結果4にしています。 ソースコードここにあります。

void CompressedGemm::calc_partially(const int index, const vector<float> &x){
    int m = B_nrows_threads[index].size();
    int cur = 0;
    for(int i=0;i<m;i++){
        int n = B_nrows_threads[index][i];
        float r = 0;
        for(int j=cur;j<cur+n;j++){
            r += B_scale_threads[index][j] * x[B_column_threads[index][j]];
        }
        ret[B_row_threads[index][cur]] += r;
        cur += n;
    }
}

vector<float> CompressedGemm::calc(const vector<float> &x){
    ret = C;
    vector<thread> ths;
    
    for(int i=0;i<n_thread;i++){
        ths.push_back(thread(&CompressedGemm::calc_partially, this, i, x));
    }
    for(int i=0;i<n_thread;i++){
        ths[i].join();
    }
    return ret;
}

結果はこんな感じです。

手法 実行時間
sonnx(マルチスレッド) 1.995秒(標準偏差 0.03405秒)
onnxruntime(マルチスレッド) 0.5048秒(標準偏差 0.04249秒)

完敗です。

SIMDとマルチスレッドの併用

では最後にSIMDとマルチスレッディングを併用してみます。 ソースコードここにあります。

手法 実行時間
sonnx(SIMD+マルチスレッド) 1.358秒(標準偏差 0.05762秒)
onnxruntime(マルチスレッド) 0.5048秒(標準偏差 0.04249秒)

あまり効果がありませんね...複数コアからの結果をまとめるのに時間がかかっているのかもしれません。

まとめ

影響力の小さい要素を削除して学習済みのモデルを軽量化した上でSIMDとマルチスレッド化を用いて推論の高速化を試みました。 マルチスレッド環境ではonnxruntimeに勝てませんでしたが、シングルスレッドではonnxruntimeより高速に推論できました。

参考


  1. 前回からOSを入れ替えたので数値が違います。

  2. ちょっとここ正しく議論できているか自信がない

WebAssemblyで自作言語用のGCを書く

前回の続きです。

概要

前回の記事ではWebAssemblyを出力するMLのサブセットコンパイラを作りました。しかし、WebAssemblyにはGabage Collection (GC) が未だに実装されていないため(2019/7/8時点)、メモリ管理は全て自分で行う必要があります。前回は超適当mallocで間に合わせていたのですが、今回はmalloc, freeを含むGCをWebAssembly(手書き)で作りました。

GCを含むMLコンパイラソースコードここで、GC本体はここです。

WebAssemblyの基礎知識

WebAssemblyはスタックマシンで実行される言語であり、各命令は引数をスタックのトップから取り出し結果をスタックのトップに積みます。トップ以外のスタックの中身を自由に見ることはできません1。また、スタックの他にもリニアメモリも持ち、i32.store, i32.loadといった命令でアクセスすることができます。リニアメモリは以下のような特徴を持ちます。

  • アドレスは0から始まる
  • アドレスは32bit
  • 整数又は浮動小数点数をstore/loadできる

GCの実装

GCを書く前にまずmallocとfreeが必要です。今回はfree listでブロックを管理する簡単なものを書きました。各ブロックは下図のような構造になっていて、次のブロックのアドレス、ブロックサイズ、フラグ、確保したメモリを保持します。フラグにはfree listの終端かどうかの情報とブロックが使用中かどうかの情報が入っています。それぞれ4byteずつ使うのでmallocのためのヘッダだけで12byteも使います。贅沢ですね。

+-------------------------------------------------------+
| next block address | block size | flags  | contents   |
| 4bytes             | 4bytes     | 4bytes | 4*n bytes  |
+-------------------------------------------------------+

mallocはfree listを先頭から走査して要求されたサイズ以上の空きブロックがあれば確保、最後まで走査して該当するブロックがなければ末尾に新たにブロックを作ります。freeは指定されたブロックの状態を空きブロックに変えるだけです。空きブロック同士を結合する処理はめんどくさいのでやっていません。

このmallocとfreeを使ってGCを作ります。『ガベージコレクション』によれば、GCには次の3種類があります。

  • マーク・アンド・スイープ
  • コピーGC
  • 参照カウント

このうち、マーク・アンド・スイープとコピーGCは生きているオブジェクトを探索するのにルートセットが必要です。ルートセットは一般にレジスタ、スタック、グローバル変数から構成されます。ところが、WebAssemblyではスタックの中身を自由に見ることができないためルートセットが必要なGCは採用できません2。このため今回は参照カウント方式でGCを実装しました。

gc_malloc を呼び出すとmallocが指定されたサイズ+GCのヘッダサイズ分のメモリをリニアメモリ上に確保し、GCの管理情報を書き込んだ上でそのアドレスを返します。GCの管理領域は下図のようになっていて、確保したメモリのサイズ、参照カウントの値、フラグ、オブジェクトのためのメモリ領域から構成されています。フラグには参照カウント時に行う深さ優先探索用のビットと格納しているオブジェクトが値(整数値又は浮動小数点数)か否かの情報が含まれます。

+----------------------------------------------------+
| memory size | reference count | flags  | contents  |
| 4bytes      | 4bytes          | 4bytes | 4*n bytes |
+----------------------------------------------------+

これがmallocのcontentsの中に入っているので実際にはこのようになります。

+---------------+------------+--------+--------------------------------------------------------+
| next malloc   | malloc     | malloc | malloc contents                                        |
| block address | block size | flags  | +----------------------------------------------------+ |
| 4bytes        | 4bytes     | 4bytes | | memory size | reference count | flags  | contents  | |
|               |            |        | | 4bytes      | 4bytes          | 4bytes | 4*n bytes | |
|               |            |        | +----------------------------------------------------+ |
+---------------+------------+--------+--------------------------------------------------------+  

参照カウントはコンパイラが変数のスコープに応じてオブジェクトの参照カウントを増減させるコード(gc_increase_rc, gc_decrease_rc)を挿入することで操作します。基本的にオブジェクトの作成時に参照カウントが1増加し、そのオブジェクトを指す変数のスコープが終了したところで参照カウントが1減らします。参照カウントが0になるとそのオブジェクトは解放され、またそのオブジェクトが保持していたアドレスが指すオブジェクトの参照カウントを1減らします。この操作は再帰的に行われます。ただし、解放されたオブジェクトが値だった場合はこの再帰的な操作は行われません。

gc_increase_rc, gc_decrease_rcはこんな感じで挿入される

(i32.const 1)
(i32.store)))
(; let fun_fib_0 end ;)
(call $gc_increase_rc)
(get_local $val_b_6)
(call $gc_decrease_rc)
(drop)
(get_local $val_a_5)
(call $gc_decrease_rc)
(drop)

基本的に参照カウントの操作は変数のスコープと連動させればよいのですが、関数終了時には特殊な処理が必要になります。関数内のローカル変数が指すオブジェクトは関数の終端で解放されますが、関数の戻り値が開放されるオブジェクトを含む場合、すでに解放されたオブジェクトが関数の戻り値に含まれることになります。このようなバグを防ぐためには、関数内のローカル変数を開放する処理を行う前に戻り値オブジェクトの参照カウントをインクリメントしておく必要があります。

その他にも配列の要素を書き換えるときは前のオブジェクトの参照カウントを減らしてから、新しく代入するオブジェクトの参照カウントをインクリメントする必要があるなど、細かいところでいろいろハマりました。

結果

フィボナッチ数列を第10項まで計算するプログラムでプログラム終了時のメモリ使用量をGC有りの場合と無しの場合で結果を比較してみると、以下のようになりました。使用メモリの4分の3ぐらいを解放できているのでそれなりに(僕が)満足できる結果です。

- メモリ使用量
GCなし 22868byte
GCあり 4660byte
let rec fib n = 
  if n < 3 
    then 1
    else 
        let a = fib (n - 1) in
        let b = fib (n - 2) in
        a + b in
let rec print x =
  if 10 < x
    then 1
    else
      let a = fib x in
      print_i32 a;
      print (x+1) in
print 0

参考リンク

余談

  • コンパイル時にWebAssemblyで書いたGCをリンクする必要があるのですが、スマートな方法が見つからずテキスト表現のまま無理やりリンクしています。lldを使えばいいのではないか?という情報を頂いたのですが、WebAssemblyのテキスト表現をELF形式に変換する方法がわかりませんでした。
  • WebAssemblyの直書きは辛い...バグると何も言わずに止まっちゃうし。
  • 公式の人は早くGCを組み込んで欲しい。やっぱりスタックを自由に走査できない環境でGCを書くのは無理がある。

  1. 現在のスタックの長さを記録しておけばリニアメモリにスタックの中身をすべて書き出すことはできる。ただしstore/load命令が整数と浮動小数点数で異なるので、スタックに積まれている全ての値について整数なのか浮動小数点数なのかという情報も保存しておく必要がある。

  2. 注1のようにスタックの中身をリニアメモリに書き出せれば、ルートセットが必要な方式も採用できる。

ONNXファイルから不要な枝を削ってMNISTの推論を高速化してみる

この記事の中のソースコードは全てhttps://github.com/akawashiro/sonnxにあります。

概要

背景

機械学習の学習済みモデルを小さなデバイスで動かす、というのが最近流行っているそうです。機械学習では、学習には大きな計算コストがかかりますが、推論はそれほど大きな計算コストがかかりません。このため、学習だけを別のコンピュータで行っておいて、実際の推論は小さなデバイスで行うということが可能です。

ただし、推論だけでもそれなりに計算資源が必要です。そこで、学習済みのモデルの高速化が重要になります。Raspberry Piに搭載されているGPUを使うIdeinとか有名です。

僕も学習済みモデルの推論を高速化できそうな方法を思いついたので実験してみました。

イデア

今回はMNISTを分類する学習済みモデルを高速化します。今回使用するモデルは次の図のようなものです。画像は28*28(=784)pxなので入力は784個、出力は各数字の確率なので10個あり、中間層が2つ挟まっています。各層間は全結合しており、活性化関数としてReluを使います。

このモデルを教師データを使って学習すると、枝の重みが変わってこんな感じになります。

僕のアイデアは学習後のネットワークから重みの小さい枝を取り去ってもちゃんと動くんじゃないか、というものです。重みの小さい枝を取り去るとこんな感じになります。

枝の本数が少なくなれば、推論を高速化できそうな気がします。

イデアの裏付け

実際に学習後のモデルで赤丸で囲った部分の重みの分布を確認します。

分布はこのようになっています。

重みが0の部分が非常に多いです。まず、学習済みモデルから重みが0の枝を削除しても推論結果に影響しないはずです。また、グラフが左右対称になっているので、絶対値の小さい順に削除していけば、各パーセプトロンへの入力はそれほど変化しない気がします。

手法

やることは非常に単純です。ニューラルネットワーク中の各層間枝の重みの統計を取り、重み上位何%かを残して残りを削除します。つまり、何らかの方法で学習済みのモデルから枝の重みのデータを取り出し、枝をカットし、さらに加工後のモデルのデータを使って推論できるようにします。

幸い今はONNXという良いものがあります。ONNXとは学習済みモデルのデータを出力する形式の一つで、多くのフレームワークが対応しています。

今回はChainerで書いたモデルからONNXデータを出力し、そのデータを加工することにしました。また加工後のデータはC++で書いた俺俺ONNXランタイムに実行してもらうことにしました。

纏めると、
1. Chainerでニューラルネットワークを書いて学習する
2. 学習済みのニューラルネットワークからONNXデータを出力する
3. ONNXデータを俺俺ONNXランタイムに読み込んで加工、実行する

となります。一つづつ何をやるのかを説明します。

1. Chainerでニューラルネットワークを書いて学習する

2. 学習済みのニューラルネットワークからONNXデータを出力する

1と2は簡単です。onnx-chainerを使えばすぐにできます。

python3 learn_mnist.py

mnist.onnxというファイルができます。

3. ONNXデータを俺俺ONNXランタイムに読み込んで加工、実行する

ここが大変でした。ディープラーニングフレームワークからのONNXモデルの出力は多くの人が試しているのですが、出力したONNXモデルをチューニングしようとする人はほとんどいないようです。

3.1 ONNXデータを解析する

とりあえずnetronというONNXの可視化ツールでmnist.onnxを可視化してみました。

GemmはGeneral matrix multiplyの略です。各Gemmノードは行列BとベクトルCを持ち、ベクトルxを入力としてBx+Cを出力します。Relu活性化関数です。

Reluはmax(0,x)で定義されている関数のでONNXから抽出する必要は無く、各Gemmノードの行列BとベクトルCの情報だけをを抽出できれば良いです。

今回は各GemmノードのBCをテキストファイルとして抽出します。

> python3 analyze_mnist_onnx.py

は次のファイルを出力します。

  • *************_matrix.txt
    mnist.onnxの全てのGemmノードのB行列とC行列
  • *************_matrix.png
    各行列の中の重みの分布
  • mnist.onnx.json
    ONNXファイルをJSONに変換したもの
  • mnist_train.txt
  • mnist_test.txt
    MNISTをC++から読み込みやすくするためにテキストファイルに変換したもの

*************_matrix.txtがどのGemmノードに対応するのかはmnist.onnx.jsonを睨むとわかります(←ここ超不親切)。

3.2 抽出したデータを加工、実行する
g++ -O3 -mtune=native -march=native -mfpmath=both sonnx.cpp && ./a.out

で出力した重みのデータを読み込み、推論を実行します。sonnx.cppは簡単なONNXランタイムになっており、mnist.onnxから抽出した行列のデータとMNISTの画像データのテキストファイルから推論を行います。デフォルトではmnist_test.txtの10000枚について推論を行います。

出力はこのようになります。

accuracy, time
0.9817000031, 15.93887043
compress_ratio, accuracy, time
0, 0.9817000031, 36.66616821
0.05, 0.9817000031, 34.61940384
0.1, 0.9818000197, 32.64707184
0.15, 0.9818000197, 30.78090286
0.2, 0.9817000031, 29.01914406
..........................

2行目は読み込んだデータを加工せずで隣接行列表現で保持したときの推論の精度と計算時間、4行目以降は読み込んだデータを加工したときの推論の圧縮率(compress_ratio)、精度と計算時間です。圧縮率が0のときは全く加工しないのと同じ、圧縮率が0.3のときは枝の重みの絶対値が小さい30%を削除したときの結果です。圧縮率が1になると全ての枝が削除されます。

データを加工して枝を削除するとき、sonnx.cppでは行列データを非零成分のインデックスとその値の組の配列として保持しています。これは枝を削除したときに保持するデータ量が減らし、高速に計算するためです。

結果

圧縮率を変化させたときの精度と計算時間のグラフです。圧縮率を0.8まで上げても推論の精度が変わっていません。これは学習済みモデル中の8割の枝を削除しても、推論精度が保てるということです。驚きですね!

一方、計算時間はほぼ枝の数に反比例しています。これは予想通りでした。

表にしてみます。推論精度を1%犠牲にするだけで2倍も高速化できています。やったね!

圧縮率 推論精度 計算時間(秒)
0 0.981 15.9
0.8 0.971 7.04

※ 圧縮率0のときの値は重みのデータを隣接行列表現で保持したときのものです。

おまけ

最後にonnxruntimeで元の学習済みモデルを実行したときの時間を計ってみます。

> python3 analyze_mnist_onnx.py
Accuracy rate =  0.9817 , time =  3.285153865814209

え、3秒?? 普通に考えると16秒ぐらいになるはずです。なぜこんなに速いんだ...

理由は色々あると思いますが、僕が思いつくのは以下の2つです。

  • AVXなどのSIMD拡張命令を使っている
    (僕もAVXを使おうとしましたが、scatter命令が僕のPCで使えないので諦めました。)
  • CPUのキャッシュに乗るようなプログラムになっている

関連記事・研究

まとめ

  • ONNXランタイムを自作したらナイーブな実装の二倍の速度で推論ができた
  • onnxruntimeには勝てなかった

高階関数を基本的な関数の合成で作った関数でQuickCheckする

高階関数をQuickCheckでテストしてみる

QuickCheckを知っていますか?
QuickCheckと言うのはHaskellのデータ駆動型のテスト用ライブラリで
テストしたい関数を指定するとその引数に合わせて適当なテストデータを生成してくれます。

では高階関数(関数を引数に取る関数)をQuickCheckでテストすると
どんなテストデータ(関数)を生成してくれるのでしょうか。
ちょっと確認してみましょう。

-- QCTest.hs
import Test.QuickCheck
import Test.QuickCheck.Function

prop :: Fun Integer Integer -> Bool
prop f = apply f 10 == apply f 20

main = quickCheck prop
% stack runghc QCTest.hs
*** Failed! Falsifiable (after 2 tests and 12 shrinks):
{20->0, _->1}

...なんだこの関数は?
入力と出力の組が列挙されてるだけやんけ。
僕の思ってる関数と違うんですが。
map (+2)とかtailとかそういうやつを生成してほしいなぁ。

まともなテスト用関数を生成する

QuickCheckのテストデータが気に入らないので、まともな(主観)テストデータを生成してみました。
今回生成するテストデータは[Int] -> [Int]の関数に限定します。

そもそも関数を生成するってどうやればいいんでしょうか?
QuickCheckのように入力と出力の対を列挙するのは筋が悪い気がします。
mapや(+)、tailといった基本的な関数を合成して関数を生成したいのです。
我々がプログラミングするときってそうやりますよね。
入力と出力の組を列挙する人はあんまりいないと思います。

[Int] -> [Int]型の関数を列挙してみます。

tail
reverse
map (* 10)
map (+ 10)

tailやreverseはそもそも[Int] -> [Int] 型です。
mapは(Int -> Int) -> [Int] -> [Int] 型なので
(+ 10) :: Int -> Intを渡すと[Int] -> [Int]型になります。

こんな風にどの関数(例 map)にどの関数(例 (+ 10))を適応すれば
どんな型(例 [Int] -> [Int])になるかを規則として書き出します。
型付け規則っぽく書くとこんな感じです。
Imgur
ちょっとわかりにくいかも知れません。

[Int] -> [Int]型の関数を2つ作ってみました。
Imgur
左はmap (+ 10)のような関数、右はmap (+ 10 (* 5))のような関数を表しています。
一番下の型にある関数の規則を持ってきて、上に必要な型を書いて
さらにその必要な型をもつ関数の規則を持ってきて...という感じです。

とりあえずこの規則を使って関数を生成できるようになりました。
[Int] -> [Int]型の関数を生成した例がこちらです。

% stack exec qcfun-exe
> Test datum are following.
> ["tail","reverse","(map (* 75))","tail","reverse"]
...

今見せた例はこのようなデータ型で表現されています。

data QProg = QMap1 QProg | QMap2 QProg QProg 
             | QTail1 QProg | QTail | QRev | QMult QProg | QAdd QProg | QRand Int deriving (Eq)

instance Show QProg where
  show (QMap1 p) = "(map " ++ show p ++ ")"
  show (QMap2 p1 p2) = "(map " ++ show p1 ++ " " ++ show p2 ++ ")"
  show (QTail1 p) = "(tail " ++ show p ++ ")"
  show (QTail) = "tail"
  show QRev = "reverse"
  show (QMult p) = "(* " ++ show p ++ ")"
  show (QAdd p) = "(* " ++ show p ++ ")"
  show (QRand i) = show i

QProg型がプログラム(関数)を表す型です。
このQProg型に対して適当なShowインスタンスを定義してあげると関数っぽく見えるようになります。

作ったテストデータ用の関数で高階関数をテストする

次に作った関数を使って高階関数をテストしてみましょう。
今回テストする対象のプログラムはこれです。

func f = reverse . f
prop f = (func f [1,2,3]) == [1,2,3]

関数fを引数に取る簡単な高階関数funcの性質をテストします。

しかし作ったテストデータ用の関数はすべてQProg型です。
つまりQProg型をなんとかして[Int] -> [Int]型に変換したうえで
テストしたい高階関数に適用する必要があります。

今回はTemplate Haskellを使ってQProg型を[Int]->[Int]型のHaskellの関数に変換しました。
Template Haskellというのはプログラムの中でプログラムを作るためのツールです。
Lensなんかで使われています。 動作例がこちらです。

% stack exec qcfun-exe
...
Test results are following.
[(False,"tail"),(True,"reverse"),(False,"(map (* 75))"),(False,"tail"),(True,"reverse")]

テストデータは先頭から順に["tail","reverse","(map (* 75))","tail","reverse"]だったので正常に動いているようです。

まとめ

高階関数をまともな(主観)関数でテストできるようにしました。
しかし技術的な制約のためテストデータをコンパイル時に生成するので
テストデータを完全にランダムに生成することはできませんでした。

参考

http://haskell.g.hatena.ne.jp/mr_konn/20111218/1324220725

付録

githubへのリンクです
https://github.com/akawashiro/qcfun
以下のコマンドで実行できます。

git clone https://github.com/akawashiro/qcfun.git
stack build
stack exec qcfun-exe

ソースコードはTemplate Haskell的な事情で2つにわかれています。