指数関数expのAVX-512によるベクトル化

初めに

サイボウズ・ラボの光成です。

C++で単精度配列に対する指数関数のベクトル化をAVX-512を使って実装しました。 標準関数std::exp(float)に対する相対誤差は2e-6、速度は10倍ぐらいです。 指数関数をどうやって計算するのか、一般的な話とAVX-512に特有の部分を紹介します。

想定読者

C++とx64(x86-64)のアセンブリ言語の知識を多少仮定しますが、 なるべく少ない前提知識で読めるように心がけます。 ある程度知識のある方は近似計算から読み始めてかまいません。

実行環境

AVX-512が使える環境とx64用C++コンパイラが必要です。

コードはherumi/fmathにあります。 コンパイルにxbyakが必要なのでダウンロードして適宜includeパスを指定してください。

fmath2名前空間で

void expf_v(float *dst, const float *src, size_t n);

が定義されています。n個のfloatの配列srcのexpをdstに格納します。

for (size_t i = 0; i < n; i++) {
  dst[i] = std::exp(src[i]);
}

と(誤差を除いて)等価です。

ベンチマーク

速度

ベンチマークはexp_v.cppで行いました。

float x[3000];に対してexpを求める計算をstd::expとexpf_vとで比較しました。

環境はOS : Ubuntu 19.10, CPU : Xeon Platinum 8280 2.7GHz, compiler : gcc-9.2.1 -Ofastです。

関数 std::exp expf_v
時間(clk) 22.6K 1.8K

clkはrdtscによるCPUクロックの計測で、概ね10倍以上高速化されています。

誤差

相対誤差を(真の値 - 実装値) / 真の値とします。 std::exp(float x)を真の値としてx = -30から30まで1e-5ずつ増やして計算した値の相対誤差の平均を出すと2e-6となりました。

exp(x)の性質

指数関数exp(x)はe = 2.71828...の巾乗exp(x) = exという関数です。

数学的には

exp(x) = 1 + x + x2/2! + x3/3! + ... + xn/n! + ...

と定義されています。 xはマイナス無限大からプラス無限大までとりえます。

ここで!は階乗の記号で4! = 4 * 3 * 2 * 1です。 定義は無限個の値の和ですが、コンピュータではもちろん途中で打ち切って有限和で近似計算します。

xの絶対値が1より小さいときはxnはとても小さく、更にn!で割るのでもっと小さくなります。 したがって少ない和で打ち切っても誤差は小さくてすみます。 しかしxが1より大きいとxnはとても大きくなり、打ち切り誤差が大きくなります。 どのようにして誤差を小さくするかがポイントです。

y = axの逆関数をx = log_a(y)と書きます。 特にa = eのときlog_e(y) = log(y)と省略します。

  • ax+y = ax ay
  • log(xy) = y log(x)
  • xy = zy log_z(x) ; 底の変換公式

などが成り立ちます。

計算の範囲

まずxのとり得る範囲を調べましょう。 xの型はfloatです(ここではx64が対象なのでfloatは32bit浮動小数点数とします)。 調べてみるとx = -87.3より小さいとfloatで正しく扱える最小の数FLT_MIN=1.17e-38より小さく、 逆にx = 88.72より大きいと最大の数FLT_MAX=3.4e38よりも大きくなってinfになります。 従ってxは-87.3 <= x <= 88.72としてよいでしょう。

近似計算

2の整数巾乗2nはビットシフトを使って高速に計算できます。 したがってexp(x) = exから2の整数巾を作り出すことを考えます。

底の変換公式を使って

ex = 2x log_2(e)

と変形し、x'=x log_2(e)を整数部分nと小数部分aに分割します。

x' = n + a (|a| <= 0.5)

そうするとex = 2n × 2aです。

2nはビットシフトで計算できるので残りは2aの計算です。 ここで再度底の変換をします。

2a = ea log(2) = eb ; b = a log(2)とおく

|a| <= 0.5でlog(2) = 0.693なので|b| = |a log(2)| <= 0.346.

bが0に近い値なのでebを冒頭の級数展開を使って近似計算します。

6次の項は0.3466/6! = 2.4e-6とfloatの分解能に近いので5次で切りましょう。

eb = 1 + b + b2/2! + b3/3! + b4/4! + b5/5!

アルゴリズム

結局次のアルゴリズムを採用します。

input : x
output : e^x

1. x = max(min(x, expMax), expMin)
2. x = x * log_2(e)
3. n = round(x) ; 四捨五入
4. a = x - n
5. b = a * log(2)
6. z = 1 + b(1 + b(1/2! + b(1/3! + b(1/4! + b/5!))))
7. w = 1 << n
8. return z * w

意外と短いですね。

AVX-512での実装

AVX-512はint32_t, int8_t, double, floatなど様々な型のをまとめて処理する命令セットの名前です。 一つのレジスタが512ビットあるのでfloatなら512/32 = 16個まとめて処理できます。 レジスタはzmm0からzmm31まで32個利用できます。

AVX-512の命令概略

この記事に登場するAVX-512用の命令をまとめておきます。 詳細はIntel64 and IA-32 Architectures Software Developer Manualsを参照ください。

アセンブリ言語の表記はIntel形式で、オペランドはdst, src1, src2の順序です。 dst, srcはdst, dst, srcの省略記法です。

命令 意味 注釈
vmovaps [mem], zmm0
vmovaps zmm0, [mem]
[mem] = zmm0
zmm0 = [mem]
memは16の倍数のメモリアドレスであること
vmovups [mem], zmm0
vmovaps zmm0, [mem]
[mem] = zmm0
zmm0 = [mem]
memの制約はない
vaddps zmm0, zmm1, zmm2 zmm0 = zmm1 + zmm2 floatとして
vsubps, vmulpsなども同様
vminps zmm0, zmm1, zmm2 zmm0 = min(zmm1, zmm2) floatとして
vpaddd zmm0, zmm1, zmm2 zmm0 = zmm1 + zmm2 uint32_tとして
vpslld zmm0, zmm1, imm zmm0 = zmm1 << imm uint32_tとして
vfmadd213ps zmm0, zmm1, zmm2 zmm0 = zmm0 * zmm1 + zmm2 floatとして
vpbroadcastd zmm0, eax eaxを16個分zmm0にコピーする
vcvtps2dq zmm0, zmm1 zmm0 = round(zmm0) 結果はint型
vcvtdq2ps zmm0, zmm1 zmm0 = float(zmm0) 結果はfloat型
vrndscaleps zmm0, zmm1, 0 zmm0 = round(zmm1) 結果はfloat型

初期化

// exp_v(float *dst, const float *src, size_t n);
void genExp(const Xbyak::Label& expDataL)
{
    const int keepRegN = 7;
    using namespace Xbyak;
    util::StackFrame sf(this, 3, util::UseRCX, 64 * keepRegN);

StackFrameは関数のプロローグを生成するクラスです。 3は引数が3個、UseRCXはrcxレジスタを明示的に使う指定、 zmmレジスタの保存のため64 * keepRegN byteスタックを確保します。

    const Reg64& dst = sf.p[0];
    const Reg64& src = sf.p[1];
    const Reg64& n = sf.p[2];

StackFrameクラスのsf.p[i]で関数の引数のi番目のレジスタを表します。 WindowsとLinuxとで引数のレジスタが異なるのでここで吸収します。

    // prolog
#ifdef XBYAK64_WIN
    vmovups(ptr[rsp + 64 * 0], zm6);
    vmovups(ptr[rsp + 64 * 1], zm7);
#endif
    for (int i = 2; i < keepRegN; i++) {
        vmovups(ptr[rsp + 64 * i], Zmm(i + 6));
    }

AVX-512のZmmレジスタを保存します。 関数内でWindowsではzmm6以降、Linuxではzmm8以降を利用する場合は保存する必要があります。

    // setup constant
    const Zmm& i127 = zmm3;
    const Zmm& expMin = zmm4;
    const Zmm& expMax = zmm5;
    const Zmm& log2 = zmm6;
    const Zmm& log2_e = zmm7;
    const Zmm expCoeff[] = { zmm8, zmm9, zmm10, zmm11, zmm12 };
    mov(eax, 127);
    vpbroadcastd(i127, eax);
    vpbroadcastd(expMin, ptr[rip + expDataL + (int)offsetof(ConstVar, expMin)]);
        ...

各種定数をレジスタにセットします。 vpbroadcastdはfloat変数1個を32個Zmmレジスタにコピーする命令です。

    vpbroadcastd(expMin, ptr[rip + expDataL + (int)offsetof(ConstVar, expMin)]);

はXbyak特有の書き方です。 LabelクラスexpDataLは各種定数(ConstVarクラス)が置かれている先頭アドレスを指します。 ripで相対アドレスを利用し、Cのoffsetofマクロでクラスメンバのオフセット値を加算します。

"rip相対アドレス"
rip相対アクセス

メインループ

vminps(zm0, expMax); // x = min(x, expMax)
vmaxps(zm0, expMin); // x = max(x, expMin)
vmulps(zm0, log2_e); // x *= log_2(e)
vcvtps2dq(zm1, zm0); // zm1 = n = round(zm0)
vcvtdq2ps(zm2, zm1); // zm2 = float(zm1)
vsubps(zm0, zm2); // a = x - n
vmulps(zm0, log2);   // a *= log2

アルゴリズムの1から5行目に対応します。 vminps, vmaxpsで入力値を[expMin, expMax]の範囲内にクリッピングします。 vmulpsでlog_2(e)倍しvcvtps2dqで整数へ最近似丸め(round)します。 結果はint型になるのでそれをvcvtdq2psでfloat型に戻します。

vmovaps(zm2, expCoeff[4]); // 1/5!
vfmadd213ps(zm2, zm0, expCoeff[3]); // b * (1/5!) + 1/4!
vfmadd213ps(zm2, zm0, expCoeff[2]); // b(b/5! + 1/4!) + 1/3!
vfmadd213ps(zm2, zm0, expCoeff[1]); // b(b(b/5! + 1/4!) + 1/3!) + 1/2!
vfmadd213ps(zm2, zm0, expCoeff[0]); // b(b(b(b/5! + 1/4!) + 1/3!) + 1/2!) + 1
vfmadd213ps(zm2, zm0, expCoeff[0]); // b(b(b(b(b/5! + 1/4!) + 1/3!) + 1/2!) + 1) + 1

アルゴリズムの6行目に対応します。 vfmadd213ps(x, y, z)は積和演算命令で、x = x * y + zを実行します。

// zm1 = n
vpaddd(zm1, zm1, i127);
vpslld(zm1, zm1, 23); // 2^n

アルゴリズムの7行目に対応します。 ここはちょっとわかりにくいので次節で解説します。

vmulps(zm0, zm2, zm1);

アルゴリズムの8行目に対応します。 これでexp(x)の計算が終了です。

floatのフォーマット

アルゴリズムの7行目を実装するためにfloatのフォーマットの説明をします。 floatは符号ビットs、指数部e、仮数部fからなります。 それぞれsが1ビット、eが8ビット、fが23ビットで合計32ビットです。

役割 符号 指数部 仮数部
記号 s e f
ビット数 1 8 23

ビットパターンが[s:e:f]で表されるfloatは(-1)^s × 2^(e-127) × (1 + f/2^23)という値を表します。

たとえばx = 0なら0 = (-1)0 × 20 × (1 + 0)とかけるので 符号は0(0または正)、指数部はe = 127、仮数部f = 0です。 逆にs = 1, e = 130, f = 0x123456で表されるビットパターン[s:e:f]=0xc1123456はfloatとして- 2130-127 × (1+f/223) = -9.137です。

前節では整数nに対してfloat(2n)が欲しかったのでした。 これに対応するビットパターンはs = 0, e = n + 127, f = 0です。 つまり((n + 127) << 23)という32ビット整数がfloatの2nを表すのです。

// zm1 = n
vpaddd(zm1, zm1, i127);
vpslld(zm1, zm1, 23); // 2^n

したがってvpadddでintの127を足し、vpslldで左23ビットシフトすることで必要な値を得られます。

floatからintへの変換

floatをintに丸める方法はいくつかあります。 今回はSSEの時代からある変換命令vcvtps2dqを使いました。 これは丸め方法がグローバルな設定に依存します。 通常モードを変更することはありませんが、もし別の設定を使うことがあるならこの方法は使えません。

intへの切り捨て専用命令vcvttps2dqというのもあります。この場合は0.5を足してから切り捨てれば四捨五入となります。 しかし負の場合は0.5を引く必要があり、やや複雑になります。

次にvroundpsという丸めモードを設定して使える命令があります。 しかしこの命令はAVX2までで何故かAVX-512用に拡張されていません。

代わりに追加されたvrndscalepsは丸めモードを自分で設定できます。 結果はfloat型になるのでintにするにはvcvtps2dqが必要です。 今回のアルゴリズムは整数にした後、floatとintの両方の型の値が必要だったのでレイテンシの短いvcvtps2dqを利用しました。

端数処理

floatを16個ずつ処理すると元の配列の個数nが16の倍数でないとき端数が出ます。 その処理方法について解説します。

AVX2までのSIMD命令では端数処理が苦手でした。 命令が16単位なので残り5個をレジスタに読み込むといった処理がやりにくいのです。 そのためSIMD命令を使わない通常の方法でループを回す方法をとることが多いです。

AVX-512ではそれを解決するためのマスクレジスタk1, ..., k7が登場しています。 マスクレジスタは各ビットがデータ処理する(1)かしない(0)かを指定するレジスタです。 データ処理しない場合は更にゼロで埋める(T_z)か値を変更しないかを選択できます。

たとえば

vmovups(zmm0|k1|T_z, ptr[src]); // zmm0 = *src;

でk1 = 0b11111;の場合、下位5ビットが立っているのでfloat *srcのsrc[0], ..., src[4]だけがzmm0にコピーされ、T_zを指定しているので残りはゼロで埋められます。

実装コードでの解説に戻ります。

and_(ecx, 15); // ecx = n % 16
mov(eax, 1); // eax = 1
shl(eax, cl);  // eax = 1 << n
sub(eax, 1);   // eax = (1 << n) - 1 ; nビットのmask
kmovd(k1, eax); // マスク設定

ecxにループの回数nが入っているとき15とandをとり端数を得ます。 ((1 << n) - 1)はデータが入っていない部分が0となるマスクです(n = 3なら0b111となる)。

vmovups(zm0|k1|T_z, ptr[src]);

vmovups(zm0, ptr[src])はsrcからzm0に16個のfloatを読む命令ですが、 vmovups(zm0|k1|T_z, ptr[src])とk1でマスクすると指定したビットが立った部分しかメモリにアクセスしません。

ここで重要な点は内部的に512ビット読み込んでからゼロにするのではなく マスクされていない領域にread/write属性が無くても例外が発生しないという点です。

安心してページ境界にアクセスできます。

係数の決め方

最後にアルゴリズムの6行目

6. z = 1 + b(1 + b(1/2! + b(1/3! + b(1/4! + b/5!))))

の値の改善方法について紹介します。 この式は無限に続く和を途中で打ち切ったものでした。 したがって必ず正しい値よりも小さくなります。

1/k!の値を微調整することで誤差をより小さく出来ます。

bのとり得る範囲はL = log(2)/2として[-L, L]でした。 関数f(x) = 1 + A + Bx + Cx2 + Dx3 + Ex4 + Fx5として 区間[-L, L]でf(x)とexp(x)の差の2乗誤差の平均を最小化する(A, B, C, D, E, F)を見つけます。

数学的にはI(A, B, C, D, E, F):=∫_[-L,L](exp(x) - f(x))2 dxとして IをA, B, C, D, E, Fで偏微分した値が全て0になる解を求めます。

Mapleでは

f := x->A+B*x+C*x^2+D*x^3+E*x^4+F*x^5;
g:=int((f(x)-exp(x))^2,x=-L..L);
sols:=solve({diff(g,A)=0,diff(g,B)=0,diff(g,C)=0,diff(g,D)=0,diff(g,E)=0,diff(g,F)=0},{A,B,C,D,E,F});
Digits:=1000;
s:=eval(sols,L=log(2)/2);
evalf(s,20);

で求めました。 雑な比較ですが単純に打ち切ったときに比べて誤差が半分程度になりました。

Sollyaを使って

remez(exp(x),5,[-log(2)/2,log(2)/2]);

で求めるやり方もあります(個人的には2乗誤差を小さくする前者の方がよい印象 : 数値計算専門の方教えてください)。

まとめ

exp(x)の近似計算の方法とAVX-512特有の命令の紹介をしました。 端数処理をうまくできるマスクレジスタは便利ですね。 昔に比べてSIMDレジスタの幅が大きくなっているのでテーブル引きをせずに計算した方が速くなることが多いようです。