- はじめに
- TL;DR
- gokart とは
- luigi.Config の利用方法
- gokart.SerializableParameter を利用したパラメーターの共通化
- まとめ
- We are hiring!
はじめに
こちらはエムスリー Advent Calendar 2024 3 日目の記事になります。 AI・機械学習チームの山本(@hiro_o918)です。
今回は Python 3.8 から導入された Protocol を利用して、gokart で任意のパラメーターを利用できるように機能拡張を行った話をします。
TL;DR
この記事では次のように任意のパラメーターオブジェクトを gokart のパラメーターとして利用する方法を紹介します。
class AddTask(gokart.TaskOnKart): config: ComputeParams = gokart.SerializableParameter(object_type=ComputeParams) def run(self): self.dump(self.config.a + self.config.b) config = ComputeParams(a=1, b=2) gokart.build(AddTask(config=config))
少し長くなってしまったので、気になる内容に応じて読み進めていただければ幸いです。
- 本機能を作る上での前提となる gokart の機能や課題感から気になる方:
- このまま読み進めてください
- gokart をすでに利用していて端的に利用方法を知りたい方:
- 仕組みと合わせて利用方法を知りたい方:
なお、今回の機能追加の PR は次のリンクから確認でき、gokart v1.8.0 から利用可能です。
gokart とは
gokart は、エムスリーが開発する機械学習パイプラインツールです。 Spotify 社が開発する luigi をラップして機能を追加しています。
ちょうど Advent Calender 2 日目で池嶋さんが書かれた記事でもその特徴が紹介されていたので、ぜひご覧ください。
まず、本題の今回の機能追加の内容に触れる前に、関連する gokart の特徴について簡単に触れておきます。
gokart におけるキャッシュの仕組み
先ほどの記事の「解決 3: 必要な部分だけを逆算して再実行」にもある gokart の特徴として、入力パラメーターに対してそのタスクの結果をキャッシュとして保存するという機能があります*1。
例として、次のような 2 つの int
パラメーターを受けて、その和を計算するタスクを考えます。
import gokart import luigi class AddTask(gokart.TaskOnKart[int]): a: int = luigi.IntParameter() b: int = luigi.IntParameter() def run(self): self.dump(self.a + self.b)
gokart は a
と b
のパラメーターの組み合わせに対して、その結果をキャッシュとして保存します。
例えば、a=1
, b=1
というパラメーターの組み合わせを過去に実行したことがある場合、再度 run
を実行することなくキャッシュから 2
(1+1
) という結果を取得します。
この機能は、入力パラメーターに対してあるタスクの結果は冪等であるという前提に立っており、結果の再利用によって効率の良い開発体験を提供できます。
また、冪等な処理を記述する力学が働くのはコードの保守性を向上させる上でも非常に有用です*2。
なお、タスクに型を記述する方法は最近追加された機能で、以下の記事で詳しく紹介されています。 www.m3tech.blog 型安全なコードを書くことができ、コードの保守性を向上させることができるので、ぜひ活用してみてください。
gokart におけるパラメーターの再利用の仕組み
次に、同じパラメーターを複数箇所で再利用する仕組みについて触れていきます。
gokart(luigi)ではパラメーター共通化の仕組みとして luigi.Config
という機能を提供しています。
次の luigi のページにパラメーター共通化のテクニックが詳しく書かれています。
luigi.Config
の書き方としては次のようになります。
class ComputeParams(luigi.Config): a: int = luigi.IntParameter() b: int = luigi.IntParameter()
Config クラスの定義の方法は比較的直感的なのですが、利用する際のクセが少しあり、今回の記事の主題にも関連してくるので、次の章で詳しく説明していきます。
luigi.Config
の利用方法
それでは、luigi.Config
を使ったパラメーターの共通化の方法について詳しく触れ、今回の機能を実装するに当たって感じた私の課題感を共有していきます。
run の中でグローバルなパラメーターとして呼び出す方法
まずは、Config をグローバルなパラメーターとして呼び出す方法について説明します。
具体的な設定例
lugi.Config
で定義したグローバルなパラメーターに対しては、luigi の説明にもある通り、次のような設定ファイルから変数を割り当てることができます。
[ComputeParams] a=1 b=2
その上で、run
の中で呼び出すことで変数へのアクセスが可能です。
class ComputeParams(luigi.Config): a: int = luigi.IntParameter() b: int = luigi.IntParameter() class AddTask(gokart.TaskOnKart[int]): def run(self): config = ComputeParams() self.dump(config.a + config.b) class MulTask(gokart.TaskOnKart[int]): def run(self): config = ComputeParams() self.dump(config.a * config.b) # 処理の実行 gokart.build(AddTask()) gokart.build(MulTask())
課題
上記の方法は、パラメーターの共通化として有用に見えますが 2 点課題があります。
- run 内で生成された Config はキャッシュを生成する入力値として扱われない
- gokart におけるキャッシュの仕組みで説明したように、gokart は入力パラメーターに対してそのタスクの結果をキャッシュとして保存します。
- しかしながら、上記の方法では Config はあくまでタスク内のローカル変数として扱われるため、キャッシュの生成には利用されません。
- 言い換えると
a
とb
の値が変わっても、タスクの再実行は行われません。 - タスクの結果に影響しないパラメーターなら問題ありませんが、そうではない場合は致命的な問題になります。
- 言い換えると
- 設定ファイルから読み込んでいるという情報が明示されていない
- 慣習と言われればそれまでですが、コンストラクタや
ComputeParams
クラスで設定ファイルから読み込んでいるという情報が明示されておらず、データの流れがわかりにくいです。
- 慣習と言われればそれまでですが、コンストラクタや
継承でパラメーターを共通化する方法
次に、継承を利用してパラメーターを共通化する方法について説明します。
具体的な設定例
次のように luigi.Config
をタスクから継承することによって、パラメーターを共通化できます。
class ComputeParams(luigi.Config): a: int = luigi.IntParameter() b: int = luigi.IntParameter() class AddTask(gokart.TaskOnKart[int], ComputeParams): def run(self): self.dump(self.a + self.b) class MulTask(gokart.TaskOnKart[int], ComputeParams): def run(self): self.dump(self.a * self.b) # 処理の実行 config = ComputeParams() gokart.build(AddTask(a=config.a, b=config.b)) gokart.build(MulTask(a=config.a, b=config.b))
この設定の場合は、タスクのパラメーターとして a
と b
が認識されるため、キャッシュの生成にも利用できます。
課題
上記の方法で、キャッシュに利用可能なパラメーターの共通化は実現できましたが、次のような課題があります。
- 継承が情報としてノイジー
- 単純にパラメーターの集合を共通化して渡すことが目的であるのに対して、継承はやや過度な要求に感じる
- 変数としてはフラットに渡すことになる
gokart.build
の部分を見ると、a
とb
がフラットに渡されていることがわかります- 本来
ComputeParams
という名前付きの粒度でパラメーターを扱いたいのに、個々のパラメーターを意識したコードを書く必要があります
- 本来
なお、luigi.Config
を継承を避ける方法として、luigi ではここの記述で、inherits
と requires
を用いたパラメーターを共通化する方法が提案されています。
しかしながら、私としては単純にコンストラクタに渡す変数セットを扱すことだけに、luigi 固有の機能に大きく依存し、タスクそのものの設計に影響を与えることは避けたいと考えました。
gokart.SerializableParameter
を利用したパラメーターの共通化
前置きが長くなりましたが、ここからが今回取り組んだことの本題です。
目指すべきパラメーターの共通化方法
上記の課題を踏まえて、まずは私の中での目指すべきパラメーターの共通化方法を考えてみました。
AddTask(config=config)
のように Config 変数をまとまりとしてコンストラクタに渡せること- パラメーターをネストして管理できること
- フラットなパラメーターではなく、木構造的にパラメーターを扱えること
- 継承などを用いてタスクの設計を変える必要がないこと
- パラメーターの責務はパラメーターで閉じていること
このような要件を満たすために、gokart.SerializableParameter
というパラメータークラスを作成しました。
次で、その具体的な利用方法について説明していきます。
gokart.SerializableParameter
の使い方
次に示すのは、gokart.SerializableParameter
を利用したパラメーターの共通化の具体例です。
import json from dataclasses import asdict, dataclass import gokart @dataclass(frozen=True) class ComputeParams: a: int b: int def gokart_serialize(self) -> str: """パラメーターとしてのハッシュを作るために str に変換する関数を実装する ただし deserialize 可能な文字列である必要はない """ return json.dumps(asdict(self)) @classmethod def gokart_deserialize(cls, s: str) -> 'ComputeParams': """CLI からの入力をパースする関数を実装する""" return cls(**json.loads(s)) class AddTask(gokart.TaskOnKart): config: ComputeParams = gokart.SerializableParameter(object_type=ComputeParams) def run(self): self.dump(self.config.a + self.config.b) class MulTask(gokart.TaskOnKart): config: ComputeParams = gokart.SerializableParameter(object_type=ComputeParams) def run(self): self.dump(self.config.a * self.config.b) config = ComputeParams(a=1, b=2) gokart.build(AddTask(config=config)) gokart.build(MulTask(config=config))
この例を見ると、前述した目指すべきパラメーターの共通化方法を満たしていることがわかります。
今回の例ではパラメーターのネストは行ってはいませんが、それを行うための技術上の制約はありません。
利用上キモとなる点は、パラメータークラスである ComputeParams
が次の 2 つのメソッドを実装していることのみです。
gokart_serialize
- パラメーターとしてのハッシュを作るために
str
に変換する関数の実装
- パラメーターとしてのハッシュを作るために
gokart_deserialize
- CLI からの入力をパースする関数の実装
例では @dataclass
を利用していますが、これは設定データを扱いやすいという理由で選んだにすぎず、任意のクラスを利用できます。
gokart.SerializableParameter
の仕組み
gokart.SerializableParameter
の実装を見ながら、内部の仕組みについて説明していきます。
コードは実際のものから一部抜粋したものです。
from typing import Generic, TypeVar, Protocol T = TypeVar('T') class Serializable(Protocol): def gokart_serialize(self) -> str: ... @classmethod def gokart_deserialize(cls: type[T], s: str) -> T: ... S = TypeVar('S', bound=Serializable) class SerializableParameter(luigi.Parameter, Generic[S]): def __init__(self, object_type: type[S], *args, **kwargs): ...
gokart.SerializableParameter
では、コンストラクタの引数である object_type
に Serializable
という Protocol を満たしたクラスを指定することを期待しています。
Protocol は、Python 3.8 から導入された型ヒントの 1 つで、あるクラスが特定のメソッドを持っていることを保証するために利用されます。
Go 言語の Interface や Rust の Trait に近い概念で、明示的な継承をせずに、あるクラスが特定のメソッドを持っていることを保証するための機能です。
Python の既存の機能としては、abc.ABC
がありますが、不要な継承を避けることができるため Protocol を利用することが個人的には好ましいです。
もちろん、同じ処理の共通化のために継承を利用する場面は出てくると思いますが、その決定をユーザーに委ねることができるのは Protocol の利点だと思います。
タスクのパラメーターとして求められる機能を、Config オブジェクト側の実装に全て押し付けることができるため、 ユーザーは任意の Config オブジェクトを利用できるようになりました。
Tips: 一部のパラメーターをキャッシュのキーに含めない方法
最後に、gokart.SerializableParameter
際に、一部のパラメーターをキャッシュのキーに含めない方法について説明します。
例えば、次のような token
の変更でタスクの再実行はしたくないが、version
の変更で再実行したいという APIConfig
があるとします。
@dataclass(frozen=True) class APIConfig: version: str token: str
この場合では次のように gokart_serialize
を実装することで、token
をキャッシュのキーに含めないようにできます。
@dataclass(frozen=True) class APIConfig: version: str token: str def gokart_serialize(self) -> str: return json.dumps({'version': self.version})
これはあくまで gokart_serialize
がキャッシュキーを算出する際に利用される関数であるためです。
これにより token
の変更でタスクの再実行が行われないようにできます。
まとめ
今回は、Python の Protocol を利用して gokart のタスクで任意のパラメーターを利用できるようにする機能を追加した話をしました。 また、この機能を追加するに至った、既存機能の背景や課題感についても触れました。 今回の機能をベースにすることで好きなライブラリである Pydantic を設定オブジェクトとして利用できるようになるため、 これからの gokart の利用がより楽しくなると感じています。
gokart 周りの設定共通化で悩んだ事がある人または Protocol を利用した設計に興味がある人にとって参考になれば幸いです。
We are hiring!
エムスリーでは自社で開発している OSS である gokart を活用して、100 を超えるマイクロサービスを運用しています。 多くの場面で使われているフレームワークである一方で、自社 OSS であるため自分が取り入れたい機能はメンバーと密に議論しながらガシガシ変更を加えることができる環境です。 積極採用中なので少しでも気になる方は、ご応募・カジュアル面談をお待ちしております!