Python asyncioのTaskGroup実践—gatherから乗り換えるべき理由

Python asyncioのTaskGroup実践—gatherから乗り換えるべき理由 | mohablog

先日、100件ほどの外部APIを並列に叩くバッチを書いていて、1件だけHTTPエラーが返ったのに他のタスクが走り続けてログが混線した、という事故に遭遇しました。原因はasyncio.gatherの挙動で、失敗したタスクの例外は拾えても、残りのタスクはキャンセルされずに実行され続けていたからです。Python 3.11で入ったasyncio.TaskGroupを使えば、この手のトラブルはほぼ起きなくなります。今回はgatherTaskGroupの挙動差を実コードで並べ、なぜ新規コードはTaskGroupに寄せたほうがいいのかを掘り下げます。検証はPython 3.12で行いました。

目次

asyncio.gatherの落とし穴—例外を飲み込んだまま走り続ける

まずは、なぜTaskGroupが必要になったのかをgatherの挙動から振り返ります。ここを押さえておかないと、TaskGroupのありがたみが伝わりづらいですね。

1つ失敗しても他は止まらない

こんな書き方、見たことがあると思います。

import asyncio

async def fetch(i: int) -> str:
    await asyncio.sleep(0.1)
    if i == 2:
        raise RuntimeError(f"task {i} failed")
    print(f"task {i} done")
    return f"result {i}"

async def main() -> None:
    results = await asyncio.gather(
        fetch(1), fetch(2), fetch(3), fetch(4), fetch(5)
    )
    print(results)

asyncio.run(main())

実行してみると、こうなります。

task 1 done
task 3 done
task 4 done
task 5 done
Traceback (most recent call last):
  ...
RuntimeError: task 2 failed

task 2が失敗しているのに、3・4・5は動き切ってから例外が投げられています。外部APIへのPOSTだったら、部分的に書き込みが走った状態で中途半端に終わるわけです。「失敗した時点で全部止めたい」というケースには、gatherは素直に応えてくれません。

return_exceptionsの罠

gather(..., return_exceptions=True)を付けると、例外も戻り値と同じリストに混ざって返ってきます。一見便利そうですが、例外をリストに混ぜるので、呼び出し側で逐一型チェックする必要が出てきます。

results = await asyncio.gather(
    fetch(1), fetch(2), fetch(3),
    return_exceptions=True,
)
for r in results:
    if isinstance(r, Exception):
        # ここで個別にハンドリング
        print("error:", r)
    else:
        print("ok:", r)

このパターン、書き捨てのスクリプトならいいんですが、コードベースが大きくなると「どの戻り値が成功でどれが失敗か」を毎回分岐する羽目になります。うっかり分岐を忘れると、例外オブジェクトに.upper()を呼ぶようなバグが仕込まれやすいです。

孤児タスクのリスク

もう一つ厄介なのが、gatherの途中で呼び出し元がキャンセルされたとき、中で走っていたタスクの扱いです。Python 3.7以降のgatherは、外側がキャンセルされると内側のタスクもキャンセルしますが、asyncio.create_taskで独立に生やしたタスクをgatherで待っている場合、キャンセル伝播の挙動が直感に反することがあります。リソース解放を確実にしたい場面では、手でtry/finallyを書く量が増えていきます。

TaskGroupの基本—with文でまとめて管理する

ここからが本題です。TaskGroupはasync withのコンテキストマネージャで、中で生やしたタスクをすべて管理してくれます。

最小のサンプル

import asyncio

async def fetch(i: int) -> str:
    await asyncio.sleep(0.1)
    if i == 2:
        raise RuntimeError(f"task {i} failed")
    print(f"task {i} done")
    return f"result {i}"

async def main() -> None:
    async with asyncio.TaskGroup() as tg:
        tasks = [tg.create_task(fetch(i)) for i in range(1, 6)]
    # with を抜けた時点で全タスク完了が保証される
    print([t.result() for t in tasks])

asyncio.run(main())

実行すると、今度はこうなります。

task 1 done
Traceback (most recent call last):
  ...
  | ExceptionGroup: unhandled errors in a TaskGroup (1 sub-exception)
  +-+---------------- 1 ----------------
    | RuntimeError: task 2 failed
    +------------------------------------

task 2の失敗を検知した時点で、task 3/4/5は自動的にキャンセルされています。task 1はすでに終わっていたので「done」が出ていますが、走行中のものはすべて止まります。これが欲しかった挙動です。

ExceptionGroupでまとめて受け取る

TaskGroupは例外をExceptionGroupにまとめて投げてくれます。複数のタスクがほぼ同時に失敗したときも、すべての例外を取り逃さない設計です。捕捉するときはexcept*構文を使います。

try:
    async with asyncio.TaskGroup() as tg:
        tg.create_task(fetch(2))
        tg.create_task(fetch(3))
except* RuntimeError as eg:
    for e in eg.exceptions:
        print("caught:", e)

出力はこんな感じです。

caught: task 2 failed

except*はPython 3.11から入った新構文で、ExceptionGroupの中から特定の型だけを拾えます。TaskGroupと組み合わせると、「ネットワークエラーだけまとめてリトライしたい」といった要件が自然に書けます。

gatherとTaskGroupの比較表

ここまで見てきた違いを表にまとめておきます。

観点asyncio.gatherasyncio.TaskGroup
Python対応バージョン3.4以降3.11以降
1タスク失敗時の他タスク挙動走り続ける自動キャンセル
複数例外の扱い最初の1つだけ投げるExceptionGroupで全部拾える
戻り値の受け取りリストで一括各Taskの.result()
動的なタスク追加事前に全部渡す必要ありwith内で自由にcreate_task
タスクのライフサイクル呼び出し側で管理withスコープで保証

表にすると明らかですが、新規コードでPython 3.11以上が使えるなら、TaskGroupを第一候補にする理由は十分にあります。gatherが不要になったわけではなく、「とにかく全部実行して、エラーもまとめて結果として受け取りたい」ようなバッチ集計系ではgatherのほうが書き味がいい場面もあります。

並列数を制御する—Semaphoreと組み合わせる

TaskGroupはタスクを気軽に生やせますが、無制限に並列化すると相手サーバーにDoSをかけることになります。並列数を絞る定番はSemaphoreです。

Semaphoreでスロットを用意する

import asyncio
import httpx

sem = asyncio.Semaphore(10)  # 同時10本まで

async def fetch(client: httpx.AsyncClient, url: str) -> int:
    async with sem:
        r = await client.get(url, timeout=5.0)
        return r.status_code

async def main() -> None:
    urls = [f"https://httpbin.org/delay/1?i={i}" for i in range(50)]
    async with httpx.AsyncClient() as client:
        async with asyncio.TaskGroup() as tg:
            tasks = [tg.create_task(fetch(client, u)) for u in urls]
    codes = [t.result() for t in tasks]
    print("200 count:", codes.count(200))

50本のリクエストを10本並列で投げる形になります。httpbinのdelay/1は1秒待つエンドポイントなので、理論値はceil(50 / 10) = 5秒あたりに落ち着きます。

200 count: 50

手動でcreate_taskと組み合わせない

Semaphoreを使うときのアンチパターンが、TaskGroupの外でasyncio.create_taskを呼んでしまうケースです。

# これはNG: TaskGroupの管理外になる
tasks = [asyncio.create_task(fetch(client, u)) for u in urls]
async with asyncio.TaskGroup() as tg:
    for t in tasks:
        tg.create_task(wait_for_it(t))

この書き方だと、TaskGroupがキャンセルを伝播する対象に、外で作ったタスクが含まれません。TaskGroupの恩恵を受けるには、かならずtg.create_task()でタスクを生やすのがルールです。

CPUバウンド処理をブロックさせない—to_threadの使いどころ

asyncioはI/Oバウンド向けの仕組みで、CPUを占有する処理を書くとイベントループが止まります。ここで使うのがasyncio.to_threadです。

ブロックする関数を別スレッドに逃がす

import asyncio
import hashlib
import time

def heavy_hash(data: bytes) -> str:
    # 意図的に重くしている
    h = data
    for _ in range(200000):
        h = hashlib.sha256(h).digest()
    return h.hex()

async def main() -> None:
    start = time.perf_counter()
    async with asyncio.TaskGroup() as tg:
        tasks = [
            tg.create_task(asyncio.to_thread(heavy_hash, f"seed-{i}".encode()))
            for i in range(4)
        ]
    elapsed = time.perf_counter() - start
    print(f"elapsed: {elapsed:.2f}s")
    for t in tasks:
        print(t.result()[:16])

手元のM1 Mac(Python 3.12)で実行した結果がこちらです。

elapsed: 1.82s
b3a6f8cb7d9e1a24
...

直接await heavy_hash(...)と書いてしまうとシリアルに実行され、4倍の時間がかかります。ただし、PythonのGILの影響でCPUバウンドは真の並列化はされない点に注意してください。純粋なCPU処理ならProcessPoolExecutorloop.run_in_executorの組み合わせのほうが素直です。

スレッド版とプロセス版の選び方

  • I/O待ちが主体(HTTP、DB、ファイル)→ そのままasync defでOK、あるいはasyncio.to_thread
  • C拡張でGILを解放する処理(NumPy、hashlibの一部)→ asyncio.to_threadで十分
  • 純Pythonの重い計算(パース、暗号、大量の辞書操作)→ ProcessPoolExecutorloop.run_in_executorで呼ぶ

ベンチマーク—シリアル・gather・TaskGroupで速度を比べる

最後に、同じ処理を3パターンで実行して速度を比較してみます。asyncio.sleep(0.1)を50回実行する想定です。

import asyncio
import time

async def work(i: int) -> int:
    await asyncio.sleep(0.1)
    return i

async def run_serial() -> None:
    for i in range(50):
        await work(i)

async def run_gather() -> None:
    await asyncio.gather(*(work(i) for i in range(50)))

async def run_taskgroup() -> None:
    async with asyncio.TaskGroup() as tg:
        for i in range(50):
            tg.create_task(work(i))

async def main() -> None:
    for name, fn in [
        ("serial", run_serial),
        ("gather", run_gather),
        ("taskgroup", run_taskgroup),
    ]:
        start = time.perf_counter()
        await fn()
        print(f"{name}: {time.perf_counter() - start:.3f}s")

asyncio.run(main())

実行結果はこうなりました。

serial:    5.063s
gather:    0.105s
taskgroup: 0.106s

シリアルに50回待つと5秒かかる処理が、並列化すると約0.1秒に収まります。gatherとTaskGroupのスループットはほぼ同等で、TaskGroupに変えても性能面でのデメリットは感じません。なので選定基準は速度ではなく、エラーハンドリングの安全性とコードの読みやすさに置くのが妥当だと思います。

移行時の注意点—既存コードをTaskGroupに置き換えるとき

既存のgatherベースのコードをTaskGroupに書き換えるとき、いくつかハマりポイントがあります。

戻り値の取り方が変わる

gatherは戻り値のリストを直接返しますが、TaskGroupは各Taskから.result()で取り出します。書き換え時に変換する必要があります。

# before
results = await asyncio.gather(*(fetch(u) for u in urls))

# after
async with asyncio.TaskGroup() as tg:
    tasks = [tg.create_task(fetch(u)) for u in urls]
results = [t.result() for t in tasks]

部分的な成功を受け取るなら戻り値設計を変える

gather(return_exceptions=True)と同じ「失敗しても走り切ってほしい」要件をTaskGroupで書くには、各タスク内部でtry/exceptして成功/失敗を戻り値に包む設計にします。

from dataclasses import dataclass

@dataclass
class FetchResult:
    url: str
    ok: bool
    value: str | None = None
    error: str | None = None

async def safe_fetch(url: str) -> FetchResult:
    try:
        v = await fetch(url)
        return FetchResult(url=url, ok=True, value=v)
    except Exception as e:
        return FetchResult(url=url, ok=False, error=str(e))

この形にしておくと、TaskGroupでもgatherでも同じロジックで呼べますし、呼び出し元はresult.okで分岐するだけで済みます。例外を値として扱う設計に寄せるのは、非同期処理全般で読みやすさが上がる書き方なので、おすすめです。

3.10以下のプロジェクトへの移行

TaskGroupはPython 3.11以降なので、古い環境ではそのままは使えません。exceptiongrouptaskgroupのバックポートパッケージが公開されていますが、依存を増やすくらいならまずランタイムを3.11以上に上げるほうが筋がいいと思います。3.11はexcept*以外にも型ヒントまわりの改善が多く、入れて困ることはあまりありません。

非同期処理と相性のいい話題として、関連記事のPythonでWebSocketリアルタイム通信を実装する──サーバー・クライアント実装とよくある落とし穴もあわせて読むと、TaskGroupの実運用イメージがつかみやすいです。

まとめ

  • gatherは1タスク失敗時に他を止めないため、全体整合性が必要な処理には向かない
  • TaskGroupはwithスコープで全タスクのライフサイクルを保証し、失敗時は自動キャンセル
  • 複数例外はExceptionGroupとしてまとめられ、except*で型別に拾える
  • 並列数制御はSemaphoreと組み合わせ、タスク生成は必ずtg.create_task()で行う
  • CPUバウンドはasyncio.to_threadProcessPoolExecutorで逃がす
  • Python 3.11以上が使えるなら、新規コードはTaskGroupを第一候補にする

gatherからTaskGroupへの移行は、コードの見た目はあまり変わらないのに、エラー時の振る舞いが格段に堅牢になります。非同期処理で「なんとなく動いているけど失敗時に何が起きるか怖い」と感じたことがある人は、一度TaskGroupで書き換えてみると、設計の安心感が変わるはずです。

よかったらシェアしてね!
  • URLをコピーしました!
  • URLをコピーしました!
目次