cockscomblog?

cockscomb on hatena blog

huggingface/transformersでストリーミング

AIにテキストを生成させるサービスでは、テキスト生成の速度が遅いとユーザー体験がよくない。ChatGPTは、生成されたテキストが一度に表示されるのではなく、少しずつ表示されるUIになっている。このことで、テキスト生成の完了まで待たずに結果を確かめられる。

テキスト生成の結果を少しずつ表示するためには、生成途中のテキストがストリーミングで返される必要がある。OpenAIのAPIにはstreamオプションがあり、有効にするとserver-sent eventsで結果を返してくれる。

huggingface/transformersのStreamer

OpenAI APIではなく、自分たちでhuggingface/transformersを動かしている場合、ストリーミングは困難だった。transformersにはストリーミングをサポートするAPIがなかったためである。

ところが、最近になってこれをサポートする実装が入ってきた。Hugging FaceのJoão Ganteさんが精力的に取り組んでいるようだ。

ただし今日(2023年4月10日)時点ではまだリリースされておらず、おそらくv4.28.0以降でリリースされるものと思われる。今回はソースからtransformersをインストールすることで、ストリーミングを試してみる。

transformersのストリーミング機能は、ドキュメントのUtilities for Generationで紹介されており、TextStreamerあるいはTextIteratorStreamermodel.generateに渡す形で利用する。TextStreamerは標準出力に書き出すもので、TextIteratorStreamerイテレータとして利用するものだ。

サーバーからストリーミングする

実際にFastAPIを使ったサーバーでTextIteratorStreamerを使う。まずは次のように、非同期ジェネレータとして関数generateを定義する。

import asyncio
from threading import Thread
from typing import AsyncIterator

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
)

tokenizer = AutoTokenizer.from_pretrained("rinna/japanese-gpt2-medium", use_fast=False)
tokenizer.do_lower_case = True

model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt2-medium")


async def generate(text: str) -> AsyncIterator[str]:
    inputs = tokenizer(text, add_special_tokens=False, return_tensors="pt")

    streamer = TextIteratorStreamer(tokenizer)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=100,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=0.9,
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        bad_words_ids=[[tokenizer.bos_token_id]],
        num_return_sequences=1,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    for output in streamer:
        if not output:
            continue
        await asyncio.sleep(0)
        yield output

TextIteratorStreamermodel.generateに渡している。ドキュメントに倣ってmodel.generateは別スレッドで実行し、TextIteratorStreamerイテレータとして回してyieldしている。TextIteratorStreamerは生成されたテキストを内部である程度バッファリングして、きりのよいところで返してくれる。ややワークアラウンドのようにも見えるが、ループの中でasyncio.sleepを行って、他の処理をブロッキングしないようにする。

FastAPIのエンドポイントからは次のようにStreamingResponseを返す。

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from server.generator import generate

app = FastAPI()


class GenerateInput(BaseModel):
    text: str


@app.post("/generate")
async def generate_post(generate_input: GenerateInput):
    return StreamingResponse(generate(generate_input.text), media_type="text/plain")

StreamingResponseは非同期イテレータを受け取るので、これで十分機能する。この実装では単純にtext/plainのレスポンスがストリームで返るので、クライアントが特別に対応していない場合でも互換性がある。より構造化されたレスポンスを返したい場合は、server-sent eventsとして返す方がよさそうだ。

フロントエンドでストリーミングされたデータを表示する。

フロントエンドでもストリーミングで表示するには少し工夫が要る。次のように非同期ジェネレータを書く。

async function* generate(text: string): AsyncGenerator<string> {
  const res = await fetch('https://example.com/generate', {
    method: 'POST',
    headers: {
      'Content-Type': 'application/json',
    },
    body: JSON.stringify({ text }),
  })
  const reader = res.body?.getReader()
  if (!reader) {
    return
  }
  const decoder = new TextDecoder()
  while (true) {
    const { value, done } = await reader.read()
    if (done) {
      break
    }
    yield decoder.decode(value)
  }
}

普通ならawait res.text()とするところで、本文の読み込みストリームを取得する。そこから逐次的に値を読み取って、TextDecoderでデコードしつつyieldしていく。

Reactなら、簡略化した例だが、次のように使う。

'use client'

import React from 'react'

async function* generate(text: string): AsyncGenerator<string> {
  (省略)
}

export function Generator() {
  const [input, setInput] = React.useState('')
  const onChange: React.ChangeEventHandler<HTMLInputElement> = (event) => {
    setInput(event.currentTarget.value)
  }
  const [output, setOutput] = React.useState('')
  const onSubmit: React.FormEventHandler<HTMLFormElement> = async (event) => {
    event.preventDefault()
    setOutput('')
    for await (const chunk of generate(input)) {
      setOutput((prev) => prev + chunk)
    }
  }
  return (
    <div>
      <form onSubmit={onSubmit}>
        <input onChange={onChange} />
        <button type="submit">
          Generate
        </button>
      </form>
      <div>{output}</div>
    </div>
  )
}

for await...ofで非同期ジェネレータから情報を読み出している。

Time to First Byte (TTFB)

ストリーミングの様子は、Webインスペクタのネットワークタブでも確かめられる。ストリーミングしないと「待機中」のまま長く待たされるが、ストリーミングすると「待機中」はすぐに終わり、「ダウンロード」が続く。つまりTTFBが改善されると言い換えられる。

ストリーミングしない場合は待機中が続く

ストリーミングすると待機中が短くなる

いかがでしたか

huggingface/transformersのStreamerによって、簡単にテキスト生成結果をストリーミングできることがわかった。まだリリースされていない機能で、リリースまでにAPIが変わるかもしれないが、リリースされるのが楽しみだ。