エムスリーテックブログ

エムスリー(m3)のエンジニア・開発メンバーによる技術ブログです

GPUで高速なモデル推論を実現するために考えること -FlashAttentionはなぜ高速か-

こちらはエムスリー Advent Calendar 2024 17日目の記事です。

AI・機械学習チームの髙橋です。チームでは先週からNeurIPS読み会が開催されており、"Deep Learning Architecture, Infrastructure"という深層学習のアーキテクチャに関するセッションを担当しました。その中でも興味深い一本として"You Only Cache Once: Decoder-Decoder Architectures for Language Models"という論文を勉強会まとめブログで紹介してます。 www.m3tech.blog この論文ではLLMの推論時に用いられるKey Value Cacheという仕組みに着目してGPUメモリの占有量やスループットを大幅に向上させるアーキテクチャを提案していました。非常にモデルが大規模化した昨今、この論文のようにGPUが持つハードウェア特性に応じてモデルのアーキテクチャを改良することがより一層求められています。

この類の技術で有名なものとしてFlashAttentionというLLMのAttention機構の推論を高速化するアルゴリズムがあり、論文では最大6倍程度Attention機構の推論速度が向上することが報告されています。しかし、FlashAttentionは行列計算の演算量を削減するわけではありません。それなのになぜこれほど大幅な高速化が達成できたのでしょうか。

答えはFlashAttentionの論文タイトルにもなっているとおり"Fast and Memory-Efficient Exact Attention with IO-Awareness"、I/Oが大幅な速度向上の鍵です。 ではなぜこれほどGPUでのI/OがAttention機構の計算でクリティカルなのか、今回は行列演算とI/O、そのバランスに着目して実際の計算量を見積もりしつつGPU上でのモデル推論速度を考えていきます。

モデル推論速度を決める2大要素

さて、LLMを始めとする深層学習モデルのアーキテクチャを考えるときには常に性能と推論速度のバランスを考えることになります。 ここでは性能は一旦置いておいて、推論速度を改善したいとき最もよく使われる指標はモデルの演算回数(FLOPs)です。

しかし、直感に反して必ずしもGPUでの推論速度は演算回数に比例しません。LLM以外、例えば画像分類モデルでも、MobileNetで提案されたDepth-Wise Convolutionを用いると計算量は何分の一かに削減し、CPU上では実際にそれ相応に推論速度が向上しますが、GPU上での推論速度は実はそれほど速くなりません。

このことは例えばDesigning Network Design Spacesが指摘していて、GPUでの推論時間はFLOPsよりモデルが推論時に中間レイヤーで生成するActivationの総量のほうがより強い相関を持つことが示されています。

なぜこうなるかといえばGPUのように高速な行列演算を実現しているハードウェアではメモリのI/O、データ転送速度のほうがボトルネックになることが度々生じるためです。

GPUに搭載される2種類のメモリ

GPUには大きく分けて2種類のメモリ、HBM(High Bandwidth Memory)とSRAM (Static Random Access Memory)が搭載されています。HBMはいわゆるGPUメモリと呼ばれているもので、例えばNVIDIA GPU H100は141GBのHBMがあり、LLMの推論時にはモデルのパラメータやアクティベーションのキャッシュを保存します。一方でSRAMはGPUで実際に演算を行うためのデータを一時的に保持するためのメモリ空間です。つまり演算を行うにはHBMからSRAMにデータを送り、その演算結果をSRAMからHBMに書き戻す処理が発生します。

SRAMは非常に高価な代物で、256KBのアクセスタイムが非常に高速なSRAM*1だと1個あたり1000円以上で売られています(リンクから検索)。NVIDIAがどのスペックのSRAMをいくらで仕入れているかはわかりませんが、H100のようなGPU1基作るのにSRAMメモリのみで数十万円(SMに加えてL2 Cache用に数十MB)を必要とします。また半導体製造技術の観点からもSRAMの物理サイズを小さくすることは難しく、いくらお金を積んでも大量のSRAMを搭載したGPUを製造することは現状では困難なようです。

したがって特に昨今の大規模化したLLMの学習/推論ではいかにして限られたHBM↔SRAMのデータ転送帯域を効率的に使うかが推論速度を決定することになります。

簡略化したモデルでTransformerの計算量を考える

ここではGPUを1. 行列演算器 2. ベクトル演算器 3. HBM 4. SRAMの4要素だけ構成されると仮定して、LLMの処理の一部を切り出して行列演算とHBM↔SRAM間のデータ転送をモデル化してみましょう。

H100を簡略化したモデルの図。演算速度は1000TFLOPs、HBM↔SRAM間の転送速度は1.5T要素/secと仮定する。

ベクトル演算器はSoftmax関数など行列積以外の処理を行う演算器ですが、LLMで生じるべクトル演算は行列積と比べて微々たるものなので計算量の推計からは除外します。

H100の公称では、FP16での演算処理速度は約1000TFLOPs、一方でSRAM↔HRAMのデータ転送の速度は約3TByte/secです。(すべてFP16でデータ転送すると仮定すると1.5T要素/sec) したがって、演算回数がデータ転送量の1000/1.5=660倍程度だと釣り合い、これよりデータ転送量が多くなる(比が660倍より小さい値になる)とI/Oのほうが支配的になると理解できます。*2

QKV projectionのケース

LLMのEmbedding入力をQuery, Key, ValueにProjection(QKV projection)する処理を例にして考えてみましょう。

QKV projectionのアルゴリズムと計算量。tex記法が崩れてしまったため画像化している。

各記号は以下の数値を表しています。

  • B: バッチサイズ
  • S: 系列長
  • E: Embedding次元
  • H: Head次元 (head_dim × head_num)

HBM↔SRAMのデータのやりとりを図にしてみるとこのようになります。

QKV ProjectionでのHBM↔SRAM間のデータ移動

QKV projectionでは演算回数はデータ転送量の1600倍となり、よって行列演算のほうがボトルネックになります。*3

Attention機構の場合

さきほどはQKV projectionを例にしましたが次にAttention機構を同様に考えてみましょう。冒頭で述べたFlashAttentionなどを使わずPyTorchでナイーブに実装するとAttentionは以下のアルゴリズムで計算され計算量は次のとおりになります。

通常のAttentionアルゴリズムと計算量

各記号は以下の数値を表しています。

  • B: バッチサイズ
  • S: 系列長
  • D: head_dim次元

QKV projectionのときと同様にHBM↔SRAMのデータ移動を図式化すると次のようになります。

AttentionでのHBM↔SRAM間のデータ移動、Attention Map Pを計算するステップ6まで

AttentionでのHBM↔SRAM間のデータ移動、ステップ7から9まで

演算回数はデータ転送量の90倍にしかならず、I/Oによって律速され大幅に推論速度が低下することになります。

なぜI/Oが支配的か

これには大きく分けて2つの理由があります。

1つは、Attentionの演算/データ転送コストが共に系列長Sの2乗head_dim次元Dに比例し、S>Dであることです。結果として、係数を除くと演算/データ転送コストはhead_dim次元Dの分しか差が出ずI/Oが支配的になります。

もう1つは、アルゴリズムに冗長なHBM↔SRAMの転送が含まれている点です。ステップ4,6ではSoftmaxを実行するためだけに(B, S, S)の転送が生じています。また、ステップ5では一度SRAMに置いてあったPをステップ7でもう一度HBMから転送しています。

FlashAttentionによる大幅なI/O改善

FlashAttentionはAttention機構の計算がI/Oによって律速されていることに着目して提案された手法です。通常のAttentionとほぼ等価な計算結果が得られ推論が高速化されるため、PyTorch2.xでは標準利用できるなど現在では非常に普及した技術となっています。

I/O改善を目的として、FlashAttentionは以下の方針で導出されたアルゴリズムです。

  1. Sequence方向にQKVをブロック分けする
  2. ブロックごとにAttentionの最終出力まで計算する
  3. ブロック間の依存関係は動的計画法的に更新していく

簡単に補足すると、1のブロック化はGPUの各SMに搭載されるSRAMの容量は限られているためで、ブロック化することでSRAMの容量内でブロックごとの処理をできるようにしています。 2はブロックごとにSoftmaxに相当する部分を計算してそのブロックの出力Oにあたる部分まで計算するため、通常のAttention処理で生じたSoftmaxを計算するためのデータ転送を省くことができます。

実際の処理を書き下すことは省きますが、演算回数とデータ転送量はそれぞれ以下のとおりになります。

FlashAttentionの計算量

ここで新たに登場したkはGPUに搭載されているSMあたりのSRAMの容量に依存しているハイパーパラメータ(SRAM容量が大きければ大きいほどブロックサイズは大きくできる)で論文中ではA100 GPUでk=256を採用していました。

通常のAttention処理と比較すると演算回数/SRAM↔HRAMのデータ転送量はおおよそ3倍になっておりI/O律速を大幅に改善できていることが分かります。

これはブロック化によってこれまでSの2乗に比例していたデータ転送量の項がD/k分削減できること、加えてXやPなどを一々HBMに書き戻さなくてよくなっているためです。

まとめ

GPUで高速なモデル推論を実現するにはモデルの演算量のみではなくHBM↔SRAM間のデータ転送量にも着目することが大事です。

特に近年はモデルが大規模化したことでハードウェア側の制約に着目した技術の重要度がますます高まっていきそうです。こうした研究や技術を引き続き追っていきたいと思います。

We are hiring !!

エムスリーAI・機械学習チームではLLMや臨床AIに取り組んでいます。モデルのアーキテクチャや計算アルゴリズムにワクワクする皆さん、エムスリーAI・機械学習チームで一緒に機械学習エンジニアやりましょう! また、学生の皆さん向けには機械学習・MLOpsインターンも募集してます。ぜひ一緒に論文を読みサービス開発していきましょう。

エムスリーでは、コンピュータビジョン・機械学習はもちろん、最新技術へのアンテナが高い仲間を歓迎しています。新卒・中途それぞれの採用、カジュアル面談やインターンも常時募集しています!

エンジニア採用ページはこちら

jobs.m3.com

カジュアル面談もお気軽にどうぞ

jobs.m3.com

インターンも常時募集しています

open.talentio.com

*1:H100だと132個のStreaming Multiprocessor(SM)に1つずつ256KBのSRAMが搭載されています。

*2:現実にこの演算/データ転送速度を使いきれることはありませんが、ここでは理想化して考えます。

*3:最近のGPU性能は行列演算性能のほうがデータ転送速度よりも速いペースで向上しているので、将来的にはこれも逆転するかもしれません。