こちらはエムスリー Advent Calendar 2024 5日目の記事です。AI・機械学習チームの三浦 (@mamo3gr) が、深層学習に基づく医療AIの開発環境とTipsについてお送りします。
- Kubernetesクラスタよる医療AIの開発
- モデルの性能改善に注力できるフレームワークPyTorch Lightning
- ハマりポイントとワークアラウンド
- まとめ
- We are hiring !!
Kubernetesクラスタよる医療AIの開発
AI・機械学習チームでは、臨床現場で使われる、いわゆる医療AIの開発に取り組んでいます。 これまでに一例として、胸部X線画像から肺動脈性肺高血圧症 (PAH) や間質性肺炎を検出するAIを開発し、 国際会議での発表や論文誌への採録だけでなく、 医療従事者を支えるソリューションとして医療施設で実際に活用されています。 それぞれの医療AIの詳細については、例えば次の記事を参照ください*1。
これら医療AIは深層学習に基づくものが大半で、開発効率の観点からKubernetesを利用しています。 チームでは、レコメンドエンジン・検索エンジンの開発や、 マーケティング向けデータ分析など幅広いプロダクト・プロジェクトを手掛けており、 cookiecutterによるリポジトリ立ち上げから、 クラスタへのAPI・バッチのデプロイまでが一気通貫で型化されています*2。 医療AIの開発においても、(GPUが必要なこと以外は)これらのプロダクトと同じに扱うことができ、 チーム共有のテンプレートやノウハウを活用することで大幅に効率化できています。 また、Kubernetesのメリットとして、プロジェクトの初期でモデルのアーキテクチャや問題設定(モデルの入出力) などの大方針を広範に・網羅的に探索したいときの圧倒的なスケールアウトは、やはり頼りになります。 「来週の中間報告までに実験結果を揃えたいから、ここは奮発してNVIDIA H100 (GPU) を8枚使おう」というように、 状況に応じて柔軟にスケールアップできるのもありがたいです。
モデルの性能改善に注力できるフレームワークPyTorch Lightning
チームでは、基本的なモデルや学習ループなどの定型コードを共有ライブラリにまとめて活用していますが、 最近ではPyTorch Lightningの利用も進みつつあります。 改めて紹介するまでもないですが、PyTroch Lightningは深層学習のフレームワークでPyTorchのラッパーです。 複雑で手間がかかりバグを埋め込みがちな実装を払い出せるため、機械学習モデルのデザインや改善に注力できます。 生のPyTorchで多くの開発者が通ってきた道(実装やバグ潰し)が、OSSという形で広く共有されているのは本当に有り難いです。 特に個人的に感心したポイントは次のとおりです。
デバイス特有のコードを書かなくてよい
動作する環境を認識し、それに応じた処理分岐をしてくれます。 GPUなどのアクセラレータが検出されれば自動的に利用してくれますし、 コードの追加なしに複数台のGPUにも対応します。 これにより、ローカルでコーディングとテストをした後に、クラスタにそのままデプロイして動作します。 また、省メモリや高速化の目的で半精度浮動小数点(いわゆるFP16)を使うのも1行の設定でできます*3。 FP16での精度不足から生じるオーバーフロー・アンダーフローについても、 それを防ぐためのスケール処理も必要に応じて勝手にやってくれているようです。
学習ログやチェックポイントの保存機能が充実している
TensorBoardはもちろん、 MLflowやWeights & Biases など複数のプラットフォームにそれぞれ対応するロガーが実装されており、 それらロガーを差し込むだけで学習ログが記録できます。 また、チェックポイントの保存でも状態の保存漏れが少なく、 「あれ? このモジュールの状態は保存されているんだっけ?」という心配がありません。 自前の実装だと、例えばoptimizerやlearning rate schedulerの状態が漏れているせいで実験結果がおかしくなったり、 分散学習(複数ノード and/or 複数GPU)での実装に骨が折れたりするので助かります。 学習ログやチェックポイントをローカルと同じ感覚でリモートストレージ*4に保存できるのも嬉しいです(保存先のパスにバケットを指定するだけ)。 TensorBoardは同様にそれらバケットを参照できるので、学習ログは次のコマンドで可視化できます (一例としてGCSの場合)。
% tensorboard --logdir gs://my-bucket/experiments/
学習ログの数によってはTensorBoardの動作が遅くなります。 その場合はローカルにログファイルを同期すると良いでしょう。
# チェックポイントは除外して、リモートのストレージをローカルのディレクトリに同期する % gcloud storage rsync gs://my-bucket/experiments ./experiments/ --recursive --exclude='.*\.ckpt' % tensorboard --logdir ./experiments
細かなユーティリティ関数やクラスの実装がある
細かなポイントをカバーするユーティリティ関数やクラスの実装があるのもありがたいです。
例えば、再現性を担保するために乱数のシードを各所で設定する必要がありますが、
これをまとめて行うメソッド lightning.seed_everything()
が提供されており、設定漏れのリスクが低減できます。
また、AUCなどの「ミニバッチごとの結果を保存しておいて最後に算出する」といったような、
算出が比較的面倒なメトリックに対して算出用のクラスが一通り揃っており*5、
車輪の再発明の手間とそれに伴うバグ発生リスクを削減できます。
ハマりポイントとワークアラウンド
ここまでPyTorch Lightningの便利さを述べてきましたが、 Kubernetesクラスタとの組み合わせは比較的ニッチで、 それほどノウハウが普及していない印象があります。 ここでは、私がハマったポイントとそのワークアラウンドを紹介します。
学習ログファイルを正しく保存する
TensorBoard用の学習ログファイルについて、保存先にリモートストレージのバケットを指定できることは前述しましたが、 実はバグにより正しく学習ログが保存・閲覧できないケースがあります*6。
原因は次のとおりです。 学習ログファイルの保存では、はじめに対象ファイル名を決めておき、 特定の間隔やタイミングで都度ログの内容を追記します。 このとき、書き込みごとにファイルをクローズします。 ローカルのファイルシステムではクローズしたファイルを再オープンして追記できますが、 リモートストレージのオブジェクトに対して単純に同様の処理を行う際に、 同名のオブジェクトに(既存の内容をマージすることなく) 新しいログの内容を上書きアップロードしてしまっています。 本来、オブジェクトへの追記機能が提供されている場合もありますが*7、 Python向けのファイルシステムの汎用インタフェースである fsspec, およびその詳細実装*8に操作が隠蔽された結果、仕様にそぐわない使い方になってしまっているようです。
ワークアラウンドとしては、tensorboardX
に実装されているSummaryWriter
クラスをインスタンス化し、ロガーの当該メンバーと差し替えます*9。
tensorboardXは、PyTorchなどTensorflow以外の深層学習フレームワークから
TensorBoard用の学習ログファイルを保存することを目的としたライブラリです。
同ライブラリでは保存クラスが独自に実装されており、
保存内容をローカルにバッファしつつ最終結果としてリモートストレージにアップロードすることで、
学習ログを正常に記録できます。
学習を継続実行する
Kubernetesクラスタでは個々のノード(計算機)を意識せずに利用できる反面、 例えばノードのメンテナンスによる停止やワークロードの再配置など、 クラスタ側の都合で予告なくJobが中断されることがあります (eviction)。 また、いくつかのクラウドでは、 可用性が保証されないかわりに安価に利用可能なスポットインスタンス (Spot VM) が提供されています。 GPUや大容量のメモリが重要になるケースで特に大きな恩恵を受けるため積極的に活用したい*10ところですが、 こちらもやはり余剰なノードが枯渇した場合など、クラウド側の都合で停止することがあります (preemption)。
ワークロード (Pod) を実行しているノードが停止したとしても、
新しいノード上で再び実行がスケジューリングされますし、
チェックポイントの保存と読み込みを行うことで中断時点から学習の再開ができます。
しかしながら、これを繰り返した結果リトライ上限 (spec.backoffLimit
) に達すると、Jobは終了してしまいます*11。
もちろん単純にこの上限数を増やすことで回避できますが、
一方でコードのバグのようなクラスタ以外の要因でリトライされ続け、
不要な占有や利用料が発生することも避けたいでしょう*12。
これに対して、spec.podFailurePolicy
を設定することで、
Podのdisruption(evictionあるいはpreemptionによるPodの途中終了)をリトライ回数にカウントしないように設定できます*13。
また、途中終了したPodが再スケジュールされずにJobが正常終了してしまうケースがあります(もちろん学習は途中です*14)。
実は、Podのdisruptionに伴い、その予告としてPod内のコンテナにSIGTERM(終了要求のシグナル)が送られています。
これはAPIやバッチが安全停止 (Graceful shutdown) するためのものですが、
PyTorch Lightningでもご丁寧にSIGTERMのハンドラが実装されており、
独自のPython例外 SIGTERMException
をraiseします。
ただしここで問題なのが、同例外のraise時に終了コード=0で終了していることです
(親クラスのSystemExit
でそうなっている*15)。
つまり、Kubernetesからは「正常終了しておりリトライ不要」と思われているわけです。
対策としては、単純ですがこの例外をキャッチして0以外の終了コードで終了します。
try: trainer.fit( ... ) except SIGTERMException: # 終了コードは https://tldp.org/LDP/abs/html/exitcodes.html を参考にした。 # 単にリトライさせる目的であれば、おそらく1-255の範囲で何でも良いはず sys.exit(128 + 15)
まとめ
AI・機械学習チームでの医療AIの開発について、深層学習のためのKubernetesクラスタ環境や、 そこでのちょっとニッチなノウハウを紹介しました。 Kubernetesクラスタもラッパーライブラリも面倒な部分を隠蔽してくれる一方で、 ニッチなユースケースでは挙動が良くわからなかったりバグがあったりして、 ちょくちょくdeep-diveする必要があり、その度に学びがありました*16。 本記事を皮切りに、ノウハウ共有やコードベースへのcontributionなど、 チームひいてはコミュニティに還元していければと思います。
We are hiring !!
エムスリーAI・機械学習チームでは、機械学習のアルゴリズムはもちろん、 その実行環境にこだわりのある仲間を歓迎しています。 新卒・中途それぞれの採用だけでなく、カジュアル面談やインターンも常時募集しています!
エンジニア採用ページはこちら
カジュアル面談もお気軽にどうぞ
インターンも常時募集しています
*1:その他、エムスリーテックブログで公開している記事としては、抗がん剤の副作用をAIで予測するや臨床現場で使われるAIを作る: 胸部X線診断AIの事例と医療画像分類の特徴などがあります
*2:例えばcruft実践入門 ~cookiecutter templateの変更に追従する~を参照ください
*3:学習を扱うクラス Trainer のコンストラクタにパラメータ precision='16-mixed' を渡すだけ
*4:オンプレミスのものもありますし、クラウドサービスではGCS (Google Cloud Storage) やS3 (Amazon Simple Storage Service) が有名ですね
*5:正確にはPyTorch Lightning本体ではなく、torchmetricsというパッケージに分かれています
*6:例えばGCSについてpytorch-lightning#17037 (GitHub)で報告されています。後述のワークアラウンドも同issueからの引用です。S3など他のストレージについては未確認です
*7:例えばGCSではResumable uploadsがあります
*9:なお、TensorBoardがインストールされていない、かつtensorboardXがインストールされている場合は後者のSummaryWriterクラスが自動的に利用されるので(GitHub)、そもそもこのバグの影響を受けません
*10:例えばGKE (Google Cloud Kubernetes Engine) のスポットインスタンスには識別用のラベルが付与されており、このラベルを手がかりにワークロードをスケジュールする設定が可能です。詳細はGoogle Cloudのドキュメントを参照ください
*11:Jobをデプロイして終業し、翌日に途中で事切れたJobを見ると悲しい気持ちになります
*12:こちらも発見時に顔が真っ青になります
*13:詳しくはKubernetesのリファレンスを参照ください
*14:"[rank: 0] Received SIGTERM: 15" のログとともに何事も無かったかのように終了しており、実験が完了していることを期待してログを見ると膝から崩れ落ちます
*15:これに対してpytorch-lightning#19916で、そもそも当該ケースでの終了コードが0以外であるべきでは、という議論もあります
*16:深層学習だけに