エムスリーテックブログ

エムスリー(m3)のエンジニア・開発メンバーによる技術ブログです

Protocol を使って、ML パイプラインツール gokart のパラメーター拡張機能を作ってみた

はじめに

本文とは関係ない自宅警備を頑張るわんこ

こちらはエムスリー Advent Calendar 2024 3 日目の記事になります。 AI・機械学習チームの山本(@hiro_o918)です。

今回は Python 3.8 から導入された Protocol を利用して、gokart で任意のパラメーターを利用できるように機能拡張を行った話をします。

TL;DR

この記事では次のように任意のパラメーターオブジェクトを gokart のパラメーターとして利用する方法を紹介します。

class AddTask(gokart.TaskOnKart):
    config: ComputeParams = gokart.SerializableParameter(object_type=ComputeParams)

    def run(self):
        self.dump(self.config.a + self.config.b)

config = ComputeParams(a=1, b=2)
gokart.build(AddTask(config=config))

少し長くなってしまったので、気になる内容に応じて読み進めていただければ幸いです。

なお、今回の機能追加の PR は次のリンクから確認でき、gokart v1.8.0 から利用可能です。

github.com

gokart とは

gokart は、エムスリーが開発する機械学習パイプラインツールです。 Spotify 社が開発する luigi をラップして機能を追加しています。

github.com

ちょうど Advent Calender 2 日目で池嶋さんが書かれた記事でもその特徴が紹介されていたので、ぜひご覧ください。

www.m3tech.blog

まず、本題の今回の機能追加の内容に触れる前に、関連する gokart の特徴について簡単に触れておきます。

gokart におけるキャッシュの仕組み

先ほどの記事の「解決 3: 必要な部分だけを逆算して再実行」にもある gokart の特徴として、入力パラメーターに対してそのタスクの結果をキャッシュとして保存するという機能があります*1

例として、次のような 2 つの int パラメーターを受けて、その和を計算するタスクを考えます。

import gokart
import luigi

class AddTask(gokart.TaskOnKart[int]):
    a: int = luigi.IntParameter()
    b: int = luigi.IntParameter()

    def run(self):
        self.dump(self.a + self.b)

gokart は ab のパラメーターの組み合わせに対して、その結果をキャッシュとして保存します。 例えば、a=1, b=1 というパラメーターの組み合わせを過去に実行したことがある場合、再度 run を実行することなくキャッシュから 2(1+1) という結果を取得します。 この機能は、入力パラメーターに対してあるタスクの結果は冪等であるという前提に立っており、結果の再利用によって効率の良い開発体験を提供できます。 また、冪等な処理を記述する力学が働くのはコードの保守性を向上させる上でも非常に有用です*2

なお、タスクに型を記述する方法は最近追加された機能で、以下の記事で詳しく紹介されています。 www.m3tech.blog 型安全なコードを書くことができ、コードの保守性を向上させることができるので、ぜひ活用してみてください。

gokart におけるパラメーターの再利用の仕組み

次に、同じパラメーターを複数箇所で再利用する仕組みについて触れていきます。 gokart(luigi)ではパラメーター共通化の仕組みとして luigi.Config という機能を提供しています。 次の luigi のページにパラメーター共通化のテクニックが詳しく書かれています。

luigi.readthedocs.io

luigi.Config の書き方としては次のようになります。

class ComputeParams(luigi.Config):
    a: int = luigi.IntParameter()
    b: int = luigi.IntParameter()

Config クラスの定義の方法は比較的直感的なのですが、利用する際のクセが少しあり、今回の記事の主題にも関連してくるので、次の章で詳しく説明していきます。

luigi.Config の利用方法

それでは、luigi.Config を使ったパラメーターの共通化の方法について詳しく触れ、今回の機能を実装するに当たって感じた私の課題感を共有していきます。

run の中でグローバルなパラメーターとして呼び出す方法

まずは、Config をグローバルなパラメーターとして呼び出す方法について説明します。

具体的な設定例

lugi.Config で定義したグローバルなパラメーターに対しては、luigi の説明にもある通り、次のような設定ファイルから変数を割り当てることができます。

[ComputeParams]
a=1
b=2

その上で、run の中で呼び出すことで変数へのアクセスが可能です。

class ComputeParams(luigi.Config):
    a: int = luigi.IntParameter()
    b: int = luigi.IntParameter()

class AddTask(gokart.TaskOnKart[int]):
    def run(self):
        config = ComputeParams()
        self.dump(config.a + config.b)

class MulTask(gokart.TaskOnKart[int]):
    def run(self):
        config = ComputeParams()
        self.dump(config.a * config.b)

# 処理の実行
gokart.build(AddTask())
gokart.build(MulTask())

課題

上記の方法は、パラメーターの共通化として有用に見えますが 2 点課題があります。

  • run 内で生成された Config はキャッシュを生成する入力値として扱われない
    • gokart におけるキャッシュの仕組みで説明したように、gokart は入力パラメーターに対してそのタスクの結果をキャッシュとして保存します。
    • しかしながら、上記の方法では Config はあくまでタスク内のローカル変数として扱われるため、キャッシュの生成には利用されません。
      • 言い換えると ab の値が変わっても、タスクの再実行は行われません。
      • タスクの結果に影響しないパラメーターなら問題ありませんが、そうではない場合は致命的な問題になります。
  • 設定ファイルから読み込んでいるという情報が明示されていない
    • 慣習と言われればそれまでですが、コンストラクタや ComputeParams クラスで設定ファイルから読み込んでいるという情報が明示されておらず、データの流れがわかりにくいです。

継承でパラメーターを共通化する方法

次に、継承を利用してパラメーターを共通化する方法について説明します。

具体的な設定例

次のように luigi.Config をタスクから継承することによって、パラメーターを共通化できます。

class ComputeParams(luigi.Config):
    a: int = luigi.IntParameter()
    b: int = luigi.IntParameter()

class AddTask(gokart.TaskOnKart[int], ComputeParams):
    def run(self):
        self.dump(self.a + self.b)

class MulTask(gokart.TaskOnKart[int], ComputeParams):
    def run(self):
        self.dump(self.a * self.b)

# 処理の実行
config = ComputeParams()
gokart.build(AddTask(a=config.a, b=config.b))
gokart.build(MulTask(a=config.a, b=config.b))

この設定の場合は、タスクのパラメーターとして ab が認識されるため、キャッシュの生成にも利用できます。

課題

上記の方法で、キャッシュに利用可能なパラメーターの共通化は実現できましたが、次のような課題があります。

  • 継承が情報としてノイジー
    • 単純にパラメーターの集合を共通化して渡すことが目的であるのに対して、継承はやや過度な要求に感じる
  • 変数としてはフラットに渡すことになる
    • gokart.build の部分を見ると、ab がフラットに渡されていることがわかります
      • 本来 ComputeParams という名前付きの粒度でパラメーターを扱いたいのに、個々のパラメーターを意識したコードを書く必要があります

なお、luigi.Config を継承を避ける方法として、luigi ではここの記述で、inheritsrequires を用いたパラメーターを共通化する方法が提案されています。

しかしながら、私としては単純にコンストラクタに渡す変数セットを扱すことだけに、luigi 固有の機能に大きく依存し、タスクそのものの設計に影響を与えることは避けたいと考えました。

gokart.SerializableParameter を利用したパラメーターの共通化

前置きが長くなりましたが、ここからが今回取り組んだことの本題です。

目指すべきパラメーターの共通化方法

上記の課題を踏まえて、まずは私の中での目指すべきパラメーターの共通化方法を考えてみました。

  • AddTask(config=config) のように Config 変数をまとまりとしてコンストラクタに渡せること
  • パラメーターをネストして管理できること
    • フラットなパラメーターではなく、木構造的にパラメーターを扱えること
  • 継承などを用いてタスクの設計を変える必要がないこと
    • パラメーターの責務はパラメーターで閉じていること

このような要件を満たすために、gokart.SerializableParameter というパラメータークラスを作成しました。 次で、その具体的な利用方法について説明していきます。

gokart.SerializableParameter の使い方

次に示すのは、gokart.SerializableParameter を利用したパラメーターの共通化の具体例です。

import json
from dataclasses import asdict, dataclass

import gokart

@dataclass(frozen=True)
class ComputeParams:
    a: int
    b: int

    def gokart_serialize(self) -> str:
        """パラメーターとしてのハッシュを作るために str に変換する関数を実装する
        ただし deserialize 可能な文字列である必要はない
        """
        return json.dumps(asdict(self))

    @classmethod
    def gokart_deserialize(cls, s: str) -> 'ComputeParams':
        """CLI からの入力をパースする関数を実装する"""
        return cls(**json.loads(s))

class AddTask(gokart.TaskOnKart):
    config: ComputeParams = gokart.SerializableParameter(object_type=ComputeParams)

    def run(self):
        self.dump(self.config.a + self.config.b)

class MulTask(gokart.TaskOnKart):
    config: ComputeParams = gokart.SerializableParameter(object_type=ComputeParams)

    def run(self):
        self.dump(self.config.a * self.config.b)

config = ComputeParams(a=1, b=2)
gokart.build(AddTask(config=config))
gokart.build(MulTask(config=config))

この例を見ると、前述した目指すべきパラメーターの共通化方法を満たしていることがわかります。

今回の例ではパラメーターのネストは行ってはいませんが、それを行うための技術上の制約はありません。 利用上キモとなる点は、パラメータークラスである ComputeParams が次の 2 つのメソッドを実装していることのみです。

  • gokart_serialize
    • パラメーターとしてのハッシュを作るために str に変換する関数の実装
  • gokart_deserialize
    • CLI からの入力をパースする関数の実装

例では @dataclass を利用していますが、これは設定データを扱いやすいという理由で選んだにすぎず、任意のクラスを利用できます。

gokart.SerializableParameter の仕組み

gokart.SerializableParameter の実装を見ながら、内部の仕組みについて説明していきます。 コードは実際のものから一部抜粋したものです。

from typing import Generic, TypeVar, Protocol

T = TypeVar('T')

class Serializable(Protocol):
    def gokart_serialize(self) -> str:
        ...

    @classmethod
    def gokart_deserialize(cls: type[T], s: str) -> T:
        ...


S = TypeVar('S', bound=Serializable)


class SerializableParameter(luigi.Parameter, Generic[S]):
    def __init__(self, object_type: type[S], *args, **kwargs): ...

gokart.SerializableParameter では、コンストラクタの引数である object_typeSerializable という Protocol を満たしたクラスを指定することを期待しています。 Protocol は、Python 3.8 から導入された型ヒントの 1 つで、あるクラスが特定のメソッドを持っていることを保証するために利用されます。 Go 言語の Interface や Rust の Trait に近い概念で、明示的な継承をせずに、あるクラスが特定のメソッドを持っていることを保証するための機能です。 Python の既存の機能としては、abc.ABC がありますが、不要な継承を避けることができるため Protocol を利用することが個人的には好ましいです。 もちろん、同じ処理の共通化のために継承を利用する場面は出てくると思いますが、その決定をユーザーに委ねることができるのは Protocol の利点だと思います。

タスクのパラメーターとして求められる機能を、Config オブジェクト側の実装に全て押し付けることができるため、 ユーザーは任意の Config オブジェクトを利用できるようになりました。

Tips: 一部のパラメーターをキャッシュのキーに含めない方法

最後に、gokart.SerializableParameter 際に、一部のパラメーターをキャッシュのキーに含めない方法について説明します。 例えば、次のような token の変更でタスクの再実行はしたくないが、version の変更で再実行したいという APIConfig があるとします。

@dataclass(frozen=True)
class APIConfig:
    version: str
    token: str

この場合では次のように gokart_serialize を実装することで、token をキャッシュのキーに含めないようにできます。

@dataclass(frozen=True)
class APIConfig:
    version: str
    token: str

    def gokart_serialize(self) -> str:
        return json.dumps({'version': self.version})

これはあくまで gokart_serialize がキャッシュキーを算出する際に利用される関数であるためです。 これにより token の変更でタスクの再実行が行われないようにできます。

まとめ

今回は、Python の Protocol を利用して gokart のタスクで任意のパラメーターを利用できるようにする機能を追加した話をしました。 また、この機能を追加するに至った、既存機能の背景や課題感についても触れました。 今回の機能をベースにすることで好きなライブラリである Pydantic を設定オブジェクトとして利用できるようになるため、 これからの gokart の利用がより楽しくなると感じています。

gokart 周りの設定共通化で悩んだ事がある人または Protocol を利用した設計に興味がある人にとって参考になれば幸いです。

We are hiring!

エムスリーでは自社で開発している OSS である gokart を活用して、100 を超えるマイクロサービスを運用しています。 多くの場面で使われているフレームワークである一方で、自社 OSS であるため自分が取り入れたい機能はメンバーと密に議論しながらガシガシ変更を加えることができる環境です。 積極採用中なので少しでも気になる方は、ご応募・カジュアル面談をお待ちしております!

エンジニア採用ページはこちら

jobs.m3.com

カジュアル面談もお気軽にどうぞ

jobs.m3.com

インターンも常時募集しています

open.talentio.com

*1:これはラッパー元の luigi でも提供されている機能です。

*2:外部データを取得する部分など、冪等ではない処理の場合は、その部分のみをタスクとして切り出して、実行時刻をパラメーターを付与することで強制的に入力値が一致しないようにする工夫しています。