Hugging Face、PyTorchの推論最適化を解説

nn.Linearの実態

転置はメタデータ書換のみ
バイアス加算はGEMM融合
addmmが単一カーネル化
compileは融合余地なし

MLPの融合効果

MLP1回で5カーネル
compileがGeLUとmulを融合
中間テンソルのHBM往復削減

手書きカーネル

Ligerは形状非依存で再コンパイル不要
詳細を読む

Hugging Faceは6月11日、PyTorchの処理を可視化するプロファイリング連載の第2回を公開しました。今回は深層学習の基本部品であるnn.Linearを題材に、GPUカーネルの実際の挙動を追い、torch.compileやLiger製の手書きカーネルとの違いを実測値で示しています。対象読者はモデルの推論速度を詰めたいエンジニアです。

まず単一のnn.Linearでは、行列積と転置、バイアス加算が一見すると別々の処理に見えますが、実態は異なります。転置を担うaten::tはGPU上でカーネルを起動せず、テンソルの形状とストライドというメタデータを書き換えるだけです。バイアス加算もcuBLASのGEMMカーネル末尾に折り込む「エピローグ」として統合され、最終的にaten::addmmという単一カーネルで完結します。

そのため単一のLinearではtorch.compileが融合する余地はほぼ残っていません。compileが消すのはGPUの計算ではなく、ビュー処理を発行するCPU側の数マイクロ秒のオーバーヘッドです。Inductorがコンパイル時にストライドを計算し、addmmを直接呼び出すよう書き換えるため、GPUの計算内容は変わりません。

効果が表れるのは三つのLinearを積んだMLPです。GeGLU構成のMLPは1回の順伝播で3つのGEMMとGeLU、乗算の計5カーネルを起動します。torch.compileはこのうちGeLUと乗算、リシェイプを1つのTriton融合カーネルにまとめ、約50MBの中間テンソルがHBMを往復する無駄を排除します。これがコンパイルによる最大の改善点です。

記事は最後に、人手で調整したLiger製カーネルを比較対象に挙げます。Ligerの実行時間は92.8マイクロ秒で、特定形状向けに最適化されたInductorの89.4マイクロ秒よりわずかに遅く見えます。しかしInductorは入力形状が変わるたびに再トレースとコンパイルが必要で、Ligerは形状が変わっても再コンパイル不要です。数マイクロ秒と引き換えに形状変化への頑健さを得ているわけです。

筆者が一貫して勧めるのは「先に予想し、それから見る」という習慣です。トレースを開く前にカーネル数や種類を予測し、想定と食い違った点こそ最も学びが多いと説きます。次回はMLPからアテンション、最終的には完全なモデルへと解説を進める予定です。