gcc ビルトイン関数の呼び出し

ダイナミックリンカのソースコード(elf/rtld.c)を見ていると興味深いコメントがあった。

  /* Partly clean the `bootstrap_map' structure up.  Don't use
     `memset' since it might not be built in or inlined and we cannot
     make function calls at this point.  Use '__builtin_memset' if we
     know it is available.  We do not have to clear the memory if we
     do not have to use the temporary bootstrap_map.  Global variables
     are initialized to zero by default.  */

memsetはビルドインでないもしくはインライン化されていない可能性があるため、ここでは呼び出してはいけない」。このコメントはダイナミックリンカの起動直後にあるため、再配置が終わっていない関数を呼び出すとアドレスが埋まっていないGOTを参照してしまい、Segmentation Faultが発生することを危惧したものだと考えられる。

ということはgccのビルトイン関数は必ずインライン化されなければならない。実際gccのビルトイン関数の説明には以下のような記述がある。

With the exception of built-ins that have library equivalents such as the standard C library functions discussed below, or that expand to library calls, GCC built-in functions are always expanded inline and thus do not have corresponding entry points and their address cannot be obtained. Attempting to use them in an expression other than a function call results in a compile-time error.

ビルトイン関数を使う小さなプログラムでも実験的に確かめることができる。

> cat main.c
#include <stdio.h>

int main() {
    int c = __builtin_clz(0x10101);
    printf("c = %d\n", c);
    return 0;
}
> gcc main.c -static -o main
> objdump -d main | grep "<main>" -A 16
0000000000401ce5 <main>:
  401ce5:       f3 0f 1e fa             endbr64 
  401ce9:       55                      push   %rbp
  401cea:       48 89 e5                mov    %rsp,%rbp
  401ced:       48 83 ec 10             sub    $0x10,%rsp
  401cf1:       c7 45 fc 0f 00 00 00    movl   $0xf,-0x4(%rbp)
  401cf8:       8b 45 fc                mov    -0x4(%rbp),%eax
  401cfb:       89 c6                   mov    %eax,%esi
  401cfd:       48 8d 3d 00 33 09 00    lea    0x93300(%rip),%rdi        # 495004 <_IO_stdin_used+0x4>
  401d04:       b8 00 00 00 00          mov    $0x0,%eax
  401d09:       e8 92 ee 00 00          callq  410ba0 <_IO_printf>
  401d0e:       b8 00 00 00 00          mov    $0x0,%eax
  401d13:       c9                      leaveq 
  401d14:       c3                      retq   
  401d15:       66 2e 0f 1f 84 00 00    nopw   %cs:0x0(%rax,%rax,1)
  401d1c:       00 00 00 
  401d1f:       90                      nop

__builtin_clzはインライン化されているようだ。

LD_AUDITとGlobal Offset Table

GNU製のLinux向け動的リンカld-linux.so には環境変数 LD_AUDIT 経由で使える監視APIがあります。このAPIを使えば、ld-linux.soの様々な挙動にフックして、その挙動を監視・干渉することができます。また、興味深いことにLD_AUDITを有効にするとGOT(Global Offset Table)周りの挙動も変化します。

PLT(Procedure Linkage Table)やGOTについて知らない場合はこちらの記事が分かりやすいです。この記事のglibcはすべてコミット 137ed5ac44を参照しています。

動作例

LD_AUDITを使ってltrace(1)まがいのものを作ってみます。以下の3つのファイルを作ります。なお、このサンプルはLinux glibcで、LD_AUDIT機能による関数トレースを参考にしました。

audit.c

#define _GNU_SOURCE
#include <link.h>
#include <stdio.h>

unsigned int la_version(unsigned int version) { return LAV_CURRENT; }

unsigned int la_objopen(struct link_map* map, Lmid_t lmid, uintptr_t* cookie) {
    printf("[LD_AUDIT] %s loaded\n", map->l_name);
    return LA_FLG_BINDTO | LA_FLG_BINDFROM;
}

ElfW(Addr) la_x86_64_gnu_pltenter(ElfW(Sym) * sym, unsigned int ndx,
                                  uintptr_t* refcook, uintptr_t* defcook,
                                  La_x86_64_regs* regs, unsigned int* flags,
                                  const char* symname, long* framesizep) {
    printf("[LD_AUDIT] %s entering\n", symname);
    return sym->st_value;
}

hello.c

#include <stdio.h>

int main() {
    puts("Hello World!");
    puts("Hello World!");
    puts("Hello World!");
    return 0;
}

run.sh

#! /bin/bash -eux

gcc -o libaudit.so -shared -fpic -fPIC -Wl,-soname,libaudit.so audit.c
gcc -o hello hello.c

run.sh を実行すると以下のような結果が得られて、ltraceっぽいことができているのがわかります。

> ./run.sh
+ gcc -o libaudit.so -shared -fpic -fPIC -Wl,-soname,libaudit.so audit.c
+ gcc -o hello hello.c
+ LD_AUDIT=libaudit.so
+ ./hello
[LD_AUDIT]  loaded
[LD_AUDIT] /lib64/ld-linux-x86-64.so.2 loaded
[LD_AUDIT] /lib/x86_64-linux-gnu/libc.so.6 loaded
[LD_AUDIT] _dl_find_dso_for_object entering
[LD_AUDIT] __tunable_get_val entering
[LD_AUDIT] __tunable_get_val entering
[LD_AUDIT] __tunable_get_val entering
[LD_AUDIT] __tunable_get_val entering
[LD_AUDIT] __tunable_get_val entering
[LD_AUDIT] __tunable_get_val entering
[LD_AUDIT] __tunable_get_val entering
[LD_AUDIT] __tunable_get_val entering
[LD_AUDIT] __tunable_get_val entering
[LD_AUDIT] __tunable_get_val entering
[LD_AUDIT] __tunable_get_val entering
[LD_AUDIT] __tunable_get_val entering
[LD_AUDIT] puts entering
Hello World!
[LD_AUDIT] puts entering
Hello World!
[LD_AUDIT] puts entering
Hello World!
+ LD_BIND_NOW=1
+ LD_AUDIT=libaudit.so

解説

la_versionLD_AUDITを使うために必ず必要な関数、la_objopenは新しいshared objectがロードされたときに呼び出される関数であり、今回は深入りしません。

la_x86_64_gnu_pltenterは PLT(Procedure Linkage Table)がGOT(Global Offset Table)を埋めるときに呼び出される関数であり、具体的にはここで呼び出されます。

この実行結果で興味深いのは [LD_AUDIT] puts entering が複数回出力されていることです。一度PLTのエントリがGOTに正しいアドレスを埋めた後は、putsの呼び出しはlibcのアドレスを埋める部分を通らず、libc中のputs本体に直接jmpしla_x86_64_gnu_pltenterを呼び出すとことはできないはずです。

ここから考えるとLD_AUDITが設定されているときはGOTにアドレスを埋める処理が行われない、と推測できます。glibcの実装を確認してみましょう。

また、dl-machine.hにはprofiling extensionが使用されている場合、GOTを書き換えない旨のコメントがあります。

/ The got[2] entry contains the address of a function which gets called to get the address of a so far unresolved function and jump to it. The profiling extension of the dynamic linker allows to intercept the calls to collect information. In this case we don't store the address in the GOT so that all future calls also end in this function. /

参考

GNU_IFUNCとは何か

GNU_IFUNC(GNU indirect function)とは同一の関数の複数の実装から、ロード時に最適な実装を選択する仕組みです1。通常、同一の関数の複数の実装が存在した場合、動的リンカ(ld-linux.so)が最初に見つけた実装が選択されますが、その選択するための基準を開発者が決定できるところに、この機能の意義があります。

GNU_IFUNCを使った簡単な例を以下に示します。resolve_foofoo の実装が選択されていることがわかります。

> cat foo.c 
#include <stdio.h>

extern void foo();
void foo_default() { printf("foo_default\n"); }
void foo_1() { printf("foo_1\n"); }
void foo_2() { printf("foo_2\n"); }

void foo() __attribute__((ifunc("resolve_foo")));

static void *resolve_foo(void) {
    if (0)
        return foo_1;
    else if (42 == 41 + 1)
        return foo_2;
    else
        return foo_default;
}
> cat main.c 
void foo();

int main() {
    foo();
    return 0;
}
> gcc -shared -fpic -fPIC foo.c -o libfoo.so
> gcc -o main main.c libfoo.so
> ./main
foo_2

glibcでは memmove, memset, memcpy, strcmp, strstr などの関数でCPUの機能に応じて高度に最適化された実装をロード時に選択するために使われています2。例えば x86-64memcpy の場合はAVX512等のSIMD拡張が使えるかどうかを判定し、最適な実装を選択するコードが入っています。3

/*  glibc/sysdeps/x86_64/multiarch/ifunc-impl-list.c  */
 IFUNC_IMPL (i, name, __memcpy_chk,
          IFUNC_IMPL_ADD (array, i, __memcpy_chk,
                  CPU_FEATURE_USABLE (AVX512F),
                  __memcpy_chk_avx512_no_vzeroupper)
          IFUNC_IMPL_ADD (array, i, __memcpy_chk,
                  CPU_FEATURE_USABLE (AVX512VL),
                  __memcpy_chk_avx512_unaligned)
          IFUNC_IMPL_ADD (array, i, __memcpy_chk,
                  CPU_FEATURE_USABLE (AVX512VL),
                  __memcpy_chk_avx512_unaligned_erms)
          IFUNC_IMPL_ADD (array, i, __memcpy_chk,
                  CPU_FEATURE_USABLE (AVX),
                  __memcpy_chk_avx_unaligned)

最後にGNU_IFUNCの実装選択関数を悪用してみます。glibcGNU_IFUNCに関する再配置情報を処理する部分を見ると、ロード時に実装を選択するための関数が呼ばれていることがわかります4

/* glibc/sysdeps/x86_64/dl-machine.h */
else if (__glibc_unlikely (r_type == R_X86_64_IRELATIVE))
    {
      ElfW(Addr) value = map->l_addr + reloc->r_addend;
      if (__glibc_likely (!skip_ifunc))
    value = ((ElfW(Addr) (*) (void)) value) ();
      *reloc_addr = value;
    }

つまり、__attribute__(constructor)5 と同様に main() の前に関数を呼び出す手段として使えそうです。やってみましょう。

> cat foo.c 
#include <stdio.h>

extern void foo();
void foo_default() { printf("foo_default\n"); }
void foo_1() { printf("foo_1\n"); }
void foo_2() { printf("foo_2\n"); }

void foo() __attribute__((ifunc("resolve_foo")));

static void *resolve_foo(void) {
    printf("resolve_foo\n");
    if (0)
        return foo_1;
    else if (42 == 41 + 1)
        return foo_2;
    else
        return foo_default;
}
> cat main.c 
#include <stdio.h>

void foo();

int main() {
    printf("beginning of main()\n");
    foo();
    return 0;
}
> gcc -shared -fpic -fPIC foo.c -o libfoo.so
> gcc -o main main.c libfoo.so
> LD_BIND_NOW=1
> ./main
resolve_foo
beginning of main()
foo_2

R_X86_64_IRELATIVEmain() 起動後に遅延して解決されるのを防ぐために LD_BIND_NOW=1 が必要でしたが、無事 main() の前に resolve_foo を呼び出せたことがわかります。

setjmp, longjmpで使われるjmp_bufには何がどのように保存されているのか

setjmp, longjmpはC言語で実行コンテキストを保存し、保存したコンテキストに復帰するために使われる関数(orマクロ)である。g++の例外実装の一つはこのsetjmp, longjmpを用いて実現されている。

この記事では 実行コンテキスト とは何を指すのか? また、その保存方法の実装について調べる。

jmp_buf

setjmpの実装はglibcsetjmp/setjmp.h にある。

typedef struct __jmp_buf_tag jmp_buf[1];

/* Store the calling environment in ENV, also saving the signal mask.
   Return 0.  */
extern int setjmp (jmp_buf __env) __THROWNL;

このjmp_bufsetjmp/bits/types/struct___jmp_buf_tag.h で定義されている。

/* Calling environment, plus possibly a saved signal mask.  */
struct __jmp_buf_tag
  {
    /* NOTE: The machine-dependent definitions of `__sigsetjmp'
       assume that a `jmp_buf' begins with a `__jmp_buf' and that
       `__mask_was_saved' follows it.  Do not move these members
       or add others before it.  */
    __jmp_buf __jmpbuf;     /* Calling environment.  */
    int __mask_was_saved;  /* Saved the signal mask?  */
    __sigset_t __saved_mask;    /* Saved signal mask.  */
  };

x86-64の場合、__jmp_bufsysdeps/x86/bits/setjmp.hで定義されている。__jmp_bufは8個のint型整数からなり、更に64bitかどうかで場合分けがなされていることがわかる。

# if __WORDSIZE == 64
typedef long int __jmp_buf[8];
# elif defined  __x86_64__
__extension__ typedef long long int __jmp_buf[8];
# else
typedef int __jmp_buf[6];
# endif

setjmp

setjmp自体の実装はsysdeps/x86_64/setjmp.Sにある。Linux x86-64では第一引数を%rdiで渡すので、第一引数として渡された構造体にレジスタの値を保存していることがわかる。保存されているのは、%rbx, %r12, %r13, %r14, %r15, %rsp, PCである。しかし、%rsp, %rbp, PCはPTR_MANGLEで処理されてから保存されている。PTR_MANGLEとは何だろうか? また、なぜ必要なのだろう?

ENTRY (__sigsetjmp)
    /* Save registers.  */
    movq %rbx, (JB_RBX*8)(%rdi)
#ifdef PTR_MANGLE
# ifdef __ILP32__
    /* Save the high bits of %rbp first, since PTR_MANGLE will
       only handle the low bits but we cannot presume %rbp is
       being used as a pointer and truncate it.  Here we write all
       of %rbp, but the low bits will be overwritten below.  */
    movq %rbp, (JB_RBP*8)(%rdi)
# endif
    mov %RBP_LP, %RAX_LP
    PTR_MANGLE (%RAX_LP)
    mov %RAX_LP, (JB_RBP*8)(%rdi)
#else
    movq %rbp, (JB_RBP*8)(%rdi)
#endif
    movq %r12, (JB_R12*8)(%rdi)
    movq %r13, (JB_R13*8)(%rdi)
    movq %r14, (JB_R14*8)(%rdi)
    movq %r15, (JB_R15*8)(%rdi)
    lea 8(%rsp), %RDX_LP    /* Save SP as it will be after we return.  */
#ifdef PTR_MANGLE
    PTR_MANGLE (%RDX_LP)
#endif
    movq %rdx, (JB_RSP*8)(%rdi)
    mov (%rsp), %RAX_LP /* Save PC we are returning to now.  */
    LIBC_PROBE (setjmp, 3, LP_SIZE@%RDI_LP, -4@%esi, LP_SIZE@%RAX_LP)
#ifdef PTR_MANGLE
    PTR_MANGLE (%RAX_LP)
#endif
    movq %rax, (JB_PC*8)(%rdi)

#ifdef SHADOW_STACK_POINTER_OFFSET
# if IS_IN (libc) && defined SHARED && defined FEATURE_1_OFFSET
    /* Check if Shadow Stack is enabled.  */
    testl $X86_FEATURE_1_SHSTK, %fs:FEATURE_1_OFFSET
    jz L(skip_ssp)
# else
    xorl %eax, %eax
# endif
    /* Get the current Shadow-Stack-Pointer and save it.  */
    rdsspq %rax
    movq %rax, SHADOW_STACK_POINTER_OFFSET(%rdi)
# if IS_IN (libc) && defined SHARED && defined FEATURE_1_OFFSET
L(skip_ssp):
# endif
#endif
#if IS_IN (rtld)
    /* In ld.so we never save the signal mask.  */
    xorl %eax, %eax
    retq
#else
    /* Make a tail call to __sigjmp_save; it takes the same args.  */
    jmp __sigjmp_save
#endif
END (__sigsetjmp)

PTR_MANGLEは sysdeps/unix/sysv/linux/x86_64/sysdep.h に定義がある。

#  define PTR_MANGLE(reg)    xor %fs:POINTER_GUARD, reg;           \
                rol $2*LP_SIZE+1, reg
#  define PTR_DEMANGLE(reg) ror $2*LP_SIZE+1, reg;                \
                xor %fs:POINTER_GUARD, reg

Linux x86-64では%fsレジスタがThread Local Storageのベースポインタとして使われることを思い出すと、このコードは

  • スレッド固有の値(%fs:POINTER_GUARD)とのxorを取り
  • $2*LP_SIZE+1だけレジスタのビットを回転させる

ことがわかる。PTR_MAGNLEの詳細は Pointer Encryptionにある。

なぜsetjmp, longjmpに際してポインタを暗号化するのだろうか? 最近のglibcではatexit関数やjmp_bufを狙った攻撃は効かない (PTR_MANGLE)Pointer Subterfugeによれば、書き込み可能な領域に生のEIPを保存すると攻撃対象になるからだそうだ。

結論

  • setjmpは%rbx, %r12, %r13, %r14, %r15, %rsp, PCをjmp_bufに保存する
  • このとき、生のEIPを保存するとそのデータ自体が攻撃目標になるため、PTR_MANGLEで暗号化する

検索エンジンを自作する夢

そろそろ学生も終わるので、1年前の失敗と反省を書き留めておこうと思います。私達(@a_kawashiro@gky360)は2018年度の未踏プロジェクトで分野限定型検索エンジンを複数組み合わせた分散型検索エンジンとして採用され、9ヶ月間を検索エンジンの作成に費やしました。この記事ではプロジェクトの紹介と、作成した検索エンジンが満足な性能を出せなかった原因について述べます。

膨大なインターネットの情報の海の中で目的の情報に辿り着くためには、検索エンジンが必要不可欠です。しかし世の中のメジャーな検索エンジンは大企業が提供してくれているものばかりで、ユーザが自分で管理・設定できるものはほとんどありません。それなら一つ作ってみるか、ということで始めたのがこのプロジェクトでした。

検索エンジンは大まかに言って3つの機能で構成されます。

  • クロール: インターネット上からウェブページをダウンロードして情報を蓄積する
  • 検索: 蓄積した情報のなかからユーザのクエリにマッチするものを探し出す
  • 表示: マッチするものを適切に並び替えて表示する

どれも最低限の実装をするならそれほど難しくないので、検索エンジンを作るのはそれほど難しくないように思われます。

しかし、現在のインターネットはあまりに巨大です1千里眼やMondouが作られた時代はせいぜい300万サイトしかありませんでしたが、今は16億サイトを超えていて素直な方法で検索エンジンを作ってもまともな結果が帰ってこないのは明らかです。

Googleのような巨大なデータセンターを使えればよいのですが、お金に余裕がなかったので、インターネット上の個人のサーバをうまく協力させて一つの検索エンジンを構成することにしました。このような検索エンジンは分散検索エンジンと呼ばれ、yacyなどがよく知られています。yacyは完全なpeer to peer型の分散検索エンジンですが、私達は

  1. 数多くのwebページを分野ごとに分類し各分野ごとの検索エンジンを作り、
  2. 分野ごとの検索エンジンを複数つなぎ合わせることでより大きな検索エンジンを作る

という設計を採用することにしました。この設計のメリットは、

  • 各分野ごとの検索エンジンはその分野に特化することで、大きなリソースを使わずに精度の高い検索サービスを提供でき、
  • 更にそれらを複数つなぎ合わせることでより広範囲の検索サービスを段階的に構築できる

の2点です。完全なpeer to peer型の分散検索エンジンに比べて、ある程度の階層構造があることで検索精度の改善が期待できます。

この設計を元に実装したのが、kearchです。kearchには2つの検索エンジンの実装が含まれます。一つは専門検索エンジン(各分野ごとの検索エンジン)、もう一つはメタ検索エンジン(専門検索エンジンをつなぎ合わせるための検索エンジン)です。

ユーザからのクエリはメタ検索エンジンによって適切な専門検索エンジンに割り振られ、ユーザはその専門検索エンジンの結果をメタ検索エンジンを通して受け取ります。例えばこの検索エンジンの構成で、"pip3 install flask"と検索するとPythonの専門検索エンジンサーバにクエリが飛んでPythonに関する検索結果が返ってきます。

f:id:a_kawashiro:20200304225906p:plain
kearchを使った検索エンジンの構成例

2019年1月から3月にかけて、私達はさくらクラウド上で上の図の検索エンジンを実際に運用してみました。確かに専門検索エンジンメタ検索エンジンも動くのですが実用には程遠い検索性能でした。その理由は以下の2つです。

まず、1つ目の「専門検索エンジンの数が足りない」ですが、これはつまり、人間が日常的に検索する分野の範囲が広すぎるということです。ブラウザの履歴を確認してもらうとわかるのですが、1人の人間が1ヶ月で検索する分野は非常に広大です。明らかに4台では足りませんでした。この問題に対する解決策は未だに思いついていません。

2つ目の「専門検索エンジンのクロール性能が低い」は具体的にはwebページの分類性能が低いということです。現在のkearchでは該当分野の単語リストを使ってwebページを分類するのですが、単語リストだけでは十分な性能が出ません。該当分野と全く関係ないwebページがクロールされていることもよくありました。

2つ目の問題にはいくつかの解決策が考えられます。例えば、該当分野の論文を読み込ませて単語リストを拡充したり、信頼できるwebページからのホップ数を考慮することで分類精度を改善できるでしょう。またFessなどのエンタープライズサーチ向けの製品を専門検索エンジンとして使えるように改造するのも解決策の一つだと思います。

最後に、専門検索エンジンについては様々な問題を感じる一方、メタ検索エンジンについてはさほど問題を感じませんでした。「小さな検索エンジンをつなぎ合わせてより大きな検索エンジンを作る」という基本設計は間違っていないと今でも思っています。

formalized-egison -- Egisonの型安全性の証明に向けて

概要

Egisonのパターンマッチの動作を定理証明支援系Coq上で定義し、「Egisonの処理系が項Mを値vに評価する」という命題をCoq上で表現できるようにしました。

動機

Egisonの中核的な機能はユーザが拡張可能なパターンマッチ機構です。例えば以下に示すように、ペアに対してその順序を無視してパターンマッチすることが可能です。

(define $unordered-pair
  (matcher
    {[<pair $ $> [something something]
      {[[$x $y] {[x y] [y x]}]}]
     [$ something
      {[$tgt {tgt}]}]}))

(match-all [1 2] unordered-pair {[<pair $a $b> [a b]]})
; ===> {[1 2] [2 1]}

しかし、その拡張性のためにこのパターンマッチの動作は大変複雑です。Egi, Nishiwaki (2018)はEgisonのパターンマッチの動作をbig-step styleの操作的意味論として次のように定義しています。 Egisonの操作的意味論

型健全性などのEgisonの性質を証明する際はこの操作的意味論を使うことになりますが、規則の種類、数が多いので証明を手で書くと間違えてしまいそうです。そこで今回の記事ではEgi, Nishiwaki (2018)の操作的意味論をCoqに書き直しました。

formalized-egison

formalized-egisonEgi, Nishiwaki (2018)によるEgisonの操作的意味論をCoqに書き直したものです。意味論を定義している部分はEgison.vの以下の部分です。

Inductive eval : (env * tm * tm)-> Prop :=
  | evarin : forall i Gamma t, (Gamma i) = Some t -> eval (Gamma, (tvar i), t)
  | evarout : forall i Gamma, (Gamma i) = None -> eval (Gamma, (tvar i), (tvar i))
  | eint : forall i e, eval (e, (tint i), (tint i))
  | etpl : forall Gamma t1 t2 v1 v2, eval (Gamma, t1, v1) -> eval (Gamma, t2, v2) -> eval (Gamma, ttpl t1 t2, ttpl v1 v2)
  | ecll : forall e ts vs, same_length_list ts vs ->
                      Forall eval (map (fun tpl => let '(t,v) := tpl in (e,t,v)) (zip ts vs)) -> eval (e, (tcll ts), (tcll vs))
  | epair : forall e t1 t2 v1 v2, eval (e, t1, v1) -> eval (e, t2, v2) -> eval (e, (tpair t1 t2), (tpair v1 v2))
  | esm : forall Gamma, eval (Gamma, tsm, tsm)
  (* | emtc : forall Gamma (ts: (list (pptn * tm * (list (dptn * tm))))), eval Gamma ((tmtc ts), (tmtc vs)) *)
  | etplmtc : forall Gamma t1 t2 v1 v2, eval (Gamma, t1, v1) -> eval (Gamma, t2, v2) -> eval (Gamma, ttplmtc t1 t2, ttplmtc v1 v2)
  | etmal : forall Gamma M N p L v_v v m_m m_e Delta_v, same_length_list Delta_v v_v -> eval (Gamma, M,v) -> evalmtc Gamma N [(m_m, m_e)] -> evalms3 [[([(p,m_m,m_e,v)], Gamma, empty)]] Delta_v ->
                                                  Forall eval (map (fun t => let '(d,v) := t in (Gamma @@ d, L, v)) (zip Delta_v v_v)) ->
                                                  eval (Gamma, (tmal M N (p, L)), tcll v_v)

  with evalmtc : env -> tm -> list (tm * env) -> Prop :=
  | emtcsm : forall Gamma, evalmtc Gamma tsm [(tsm, Gamma)]
  | emtcmtc : forall Gamma l, evalmtc Gamma (tmtc l) [((tmtc l), Gamma)]
  | emtctpl : forall Gamma m1 m2 n1 n2, eval (Gamma, (ttplmtc m1 m2), (ttplmtc n1 n2)) -> evalmtc Gamma (ttplmtc m1 m2) [(n1, Gamma); (n2, Gamma)]

  with evaldp : dptn -> tm -> option env -> Prop :=
  | edpvar : forall z v, value v -> evaldp (dpvar z) v (Some (z |-> v))
  | edppair : forall p1 p2 v1 v2 g1 g2,
      value v1 -> value v2 -> evaldp p1 v1 (Some g1) -> evaldp p2 v2 (Some g2) ->
      evaldp (dppair p1 p2) (tpair v1 v2) (Some (g1 @@ g2))
  | edpfail : forall t p1 p2, not (is_tpair t) -> evaldp (dppair p1 p2) t None

  with evalpp : pptn -> env -> ptn -> option ((list ptn) * env) -> Prop :=
  | eppdol : forall g p, evalpp ppdol g p (Some ([p], empty))
  | eppvar : forall i g m v, eval (g, m, v) -> evalpp (ppvar i) g (pval m) (Some ([], (i |-> v)))
  | epppair : forall pp1 pp2 p1 p2 g pv1 pv2 g1 g2,
                evalpp pp1 g p1 (Some (pv1,g1)) -> evalpp pp2 g p2 (Some (pv2,g2)) ->
                evalpp (pppair pp1 pp2) g (ppair p1 p2) (Some ((pv1 ++ pv2), (g1 @@ g2)))
  | eppvarfail : forall y g p, not (is_pval p) -> evalpp (ppvar y) g p None
  | epppairfail : forall pp1 pp2 p g, not (is_ppair p) -> evalpp (pppair pp1 pp2) g p None

  with evalms1 : ((list ms) * option env * option (list ms)) -> Prop :=
  | ems1nil : evalms1 ([], None, None)
  | ems1anil : forall sv g d, evalms1 ((([],g,d)::sv), (Some d), (Some sv))
  | ems1 : forall p m mg v av g d sv avv d1,
        evalma (g @@ d) (p,m,mg,v) avv d1 ->
        evalms1 ((((p,m,mg,v)::av, g, d)::sv), None, (Some ((map (fun ai => (ai ++ av, g, d @@ d1)) avv) ++ sv)))

  with evalms2 : (list (list ms)) -> (list env) -> (list (list ms)) -> Prop :=
  | ems2 : forall svv gvv svv1 gvv1 svv2,
      same_length_list3 svv gvv svv1 ->
      Forall evalms1 (zip3 svv gvv svv1) ->
      (filtersome gvv) = gvv1 ->
      (filtersome svv1) = svv2 ->
      evalms2 svv gvv1 svv2

  with evalms3 : (list (list ms)) -> (list env) -> Prop :=
  | ems3nil : evalms3 [[]] []
  | ems3 : forall svv gv svv1 dv gdv, evalms2 svv gv svv1 -> evalms3 svv1 dv -> gdv = gv ++ dv ->
                             evalms3 svv gdv

  with evalma : env -> ma -> list (list ma) -> env -> Prop :=
  | emasome : forall x g v d, evalma g (pvar x, tsm, d, v) [[]] (x |-> v)
  | emappfail : forall p g pp m sv pv d v avv g1,
      evalpp pp g p None -> evalma g (p,(tmtc pv),d,v) avv g1 ->
      evalma g (p,tmtc ((pp,m,sv)::pv),d,v) avv g1
  | emadpfail : forall p g pp m dp n sv pv d v pv1 d1 avv g1,
      evalpp pp g p (Some (pv1, d1)) -> evaldp dp v None ->
      evalma g (p, tmtc ((pp,m,sv)::pv),d,v) avv g1 ->
      evalma g (p, tmtc ((pp,m,(dp,n)::sv)::pv),d,v) avv g1
  | ema : forall p Gamma pp M dp N sigma_v Delta v phi1_v p1_v Delta1 Delta2 v1_vv m1_v,
      evalpp pp Gamma p (Some (p1_v, Delta1)) ->
      evaldp dp v (Some Delta2) ->
      eval (Delta @@ Delta1 @@ Delta2, N, tcll v1_vv) ->
      evalmtc Gamma M m1_v ->
      evalma Gamma (p, tmtc ((pp,M,(dp,N)::sigma_v)::phi1_v), Delta, v)
             ((map (fun tpl => match tpl with
                              | (ttpl v11 v12) => map (fun t => let '(v1, (m1, Gamma1), p1) := t in (p1,m1,Gamma1,v1)) (zip3 [v11;v12] m1_v p1_v)
                              | v11 => map (fun t => let '(v1, (m1, Gamma1), p1) := t in (p1,m1,Gamma1,v1)) (zip3 [v11] m1_v p1_v)
                              end
                   ) v1_vv)) empty.

最も重要なのはeval (Gamma, M, v)で「環境Gammaのもとで項Mが値vに評価される」という命題です。例えば、冒頭のプログラム中の(match-all [1 2] unordered-pair {[<pair $a $b> [a b]]}){[1 2] [2 1]}に評価される、というのは次のようなCoqの命題として表すことができます。

Definition unordered_pair: tm :=
    (tmtc [(pppair ppdol ppdol, ttplmtc tsm tsm,
            [(dppair (dpvar "x") (dpvar "y"), (tcll [(ttpl (tvar "x") (tvar "y")); ttpl (tvar "y") (tvar "x")]))]);
           (ppdol, tsm,
            [(dpvar "tgt", tcll [tvar "tgt"])])]).

Definition match_all_example: tm :=
    (tmal (tpair (tint 1) (tint 2)) unordered_pair (ppair (pvar "a") (pvar "b"),ttpl (tvar "a") (tvar "b"))).

Theorem unordered_pair_example : eval (empty, match_all_example, tcll [ttpl (tint 1) (tint 2);ttpl (tint 2) (tint 1)]).

また、Egisonの操作的意味論上でこの評価結果が正しいということも証明できます。 つまりunordered_pair_exampleを証明することができます。証明は長くなるので省略しますが、Egison.vの230行目以降にあります。

今後の課題

動機の項で述べたようにformalized-egisonの最終的な目標はEgisonの型安全性の証明です。Coqで表現すると

Theorem type_soundness:
    forall Gamma M T, is_typed Gamma M T => exists v, eval Gamma M v /\ is_typed Gamma v T.

となります。ただしis_typed Gamma M Tは環境Gammaの下で項Mに型Tがつくと読みます。 今後はtyped-egisonの型付け規則をCoqで書き直してis_typedを定義した上でtype_soundnessを証明する予定です。

MNISTを可能な限り高速に分類する

概要

  • MNISTの分類をする学習済みモデルを軽量化し、
  • 更にSIMD命令を使った高速化を行うことで、
  • onnxruntimeより高速なMNISTの分類が可能になりました。(シングルスレッドで)

前回までのあらすじ

この記事は前回の続きです。 前回はMNISTを分類する学習済みニューラルネットワークから不要な枝を削除し、軽量化した学習済みモデルを走らせる専用のランタイムsonnxを作ってMNIST分類の高速化を試みました。 今回はSIMD命令とマルチスレッド化による最適化でMNIST分類速度の限界に挑みます。

今回の目標タイムは既存のONNXランタイムonnxruntimeです。 この記事の各実行時間の計測は10回行い、その平均と分散を求めました。 onnxruntimeと最適化無しのsonnxの実行時間は、OS: Ubuntu19.04、CPU: Core i7 8th Gen、 メモリ: 16GiB、GPU無しの環境下で以下のようになりました。1

手法 実行時間
onnxruntime(シングルスレッド) 1.259秒 (標準偏差 0.1148秒)
onnxruntime(マルチスレッド) 0.505秒 (標準偏差 0.04249秒)
sonnx(最適化無し) 6.968秒 (標準偏差 0.08912秒)

機械学習の推論過程、特にMNISTを分類する場合、においてボトルネックになるのは重み行列を入力ベクトルに乗算する処理です。 前回は重み行列の中で絶対値の小さい要素を無視することで計算の効率化を図り、行列の要素の80%を無視して約5倍の高速化に成功しました。 しかし、80%を削減してもなお乗算はボトルネックでした。

具体的にはsonnx.cppのこの部分がボトルネックになります。

for(int i=0;i<n;i++){
        ret[B_row[i]] += B_scale[i] * x[B_column[i]];
}

SIMDによる高速化

まず、sonnxをSIMDで高速化してみます。 Core i7 8th GenではAVX2命令セットが使えるので256bitの演算を一度に行うことができ、 今回は32bit浮動小数点数で計算しているので最大8倍の高速化が見込めます。

しかし、元のコードはメモリへの間接参照を2つ(ret[B_row[i]]x[B_column[i]])含んでおりそのままではSIMD化するのが困難です。 まずret[B_row[i]]への書き込みは同じB_row[i]の値を持つものをまとめて計算し、メモリへの書き込みを一度に行います。 x[B_column[i]]からの読み出しはAVX2のgather命令を使ってSIMD化します。 完成したコードがこれです。

float r = 0;
__m256 vr = _mm256_setzero_ps();
for(cur;cur<n1;cur+=8){
    __m256i vc = _mm256_loadu_si256((__m256i*)(&B_column_p[cur]));
    __m256 vx = _mm256_i32gather_ps(x_p, vc, 4);
    __m256 vs = _mm256_load_ps(&B_scale_p[cur]);
    vr = _mm256_fmadd_ps(vs, vx, vr);
}
for(cur;cur<n;cur++){
    r += B_scale_p[cur] * x_p[B_column_p[cur]];
}
__attribute__((aligned(32))) float t[8] = {0};
_mm256_store_ps(t, vr);
ret[B_row_p[cur-1]] += r + t[0] + t[1] + t[2] + t[3] + t[4] + t[5] + t[6] + t[7];

実行時間はこんな感じになりました。

手法 実行時間
sonnx(SIMD) 1.121秒(標準偏差 0.02271秒)
onnxruntime(シングルスレッド) 1.259秒(標準偏差 0.1148秒)

ウェルチのt検定を用いて検定すると2有意水準5%で帰無仮説が棄却され、確かにonnxruntimeより高速に推論できています。 onnxruntime(シングルスレッド)と大きな差がつかないのはgather命令が遅いからでしょうか。

マルチスレッドによる高速化

(SIMDなしの)マルチスレッド化はどうでしょうか。 オリジナルのコードはメモリへの書き込みが一箇所(ret[B_row[i]]への書き込み)しかないので、この部分だけ複数スレッドで同時に書き込まないようにすればデータレースは起こりません。 スレッド数は手元で最適なものを探索した結果4にしています。 ソースコードここにあります。

void CompressedGemm::calc_partially(const int index, const vector<float> &x){
    int m = B_nrows_threads[index].size();
    int cur = 0;
    for(int i=0;i<m;i++){
        int n = B_nrows_threads[index][i];
        float r = 0;
        for(int j=cur;j<cur+n;j++){
            r += B_scale_threads[index][j] * x[B_column_threads[index][j]];
        }
        ret[B_row_threads[index][cur]] += r;
        cur += n;
    }
}

vector<float> CompressedGemm::calc(const vector<float> &x){
    ret = C;
    vector<thread> ths;
    
    for(int i=0;i<n_thread;i++){
        ths.push_back(thread(&CompressedGemm::calc_partially, this, i, x));
    }
    for(int i=0;i<n_thread;i++){
        ths[i].join();
    }
    return ret;
}

結果はこんな感じです。

手法 実行時間
sonnx(マルチスレッド) 1.995秒(標準偏差 0.03405秒)
onnxruntime(マルチスレッド) 0.5048秒(標準偏差 0.04249秒)

完敗です。

SIMDとマルチスレッドの併用

では最後にSIMDとマルチスレッディングを併用してみます。 ソースコードここにあります。

手法 実行時間
sonnx(SIMD+マルチスレッド) 1.358秒(標準偏差 0.05762秒)
onnxruntime(マルチスレッド) 0.5048秒(標準偏差 0.04249秒)

あまり効果がありませんね...複数コアからの結果をまとめるのに時間がかかっているのかもしれません。

まとめ

影響力の小さい要素を削除して学習済みのモデルを軽量化した上でSIMDとマルチスレッド化を用いて推論の高速化を試みました。 マルチスレッド環境ではonnxruntimeに勝てませんでしたが、シングルスレッドではonnxruntimeより高速に推論できました。

参考


  1. 前回からOSを入れ替えたので数値が違います。

  2. ちょっとここ正しく議論できているか自信がない