AIにテキストを生成させるサービスでは、テキスト生成の速度が遅いとユーザー体験がよくない。ChatGPTは、生成されたテキストが一度に表示されるのではなく、少しずつ表示されるUIになっている。このことで、テキスト生成の完了まで待たずに結果を確かめられる。
テキスト生成の結果を少しずつ表示するためには、生成途中のテキストがストリーミングで返される必要がある。OpenAIのAPIにはstream
オプションがあり、有効にするとserver-sent eventsで結果を返してくれる。
huggingface/transformersのStreamer
OpenAI APIではなく、自分たちでhuggingface/transformersを動かしている場合、ストリーミングは困難だった。transformersにはストリーミングをサポートするAPIがなかったためである。
ところが、最近になってこれをサポートする実装が入ってきた。Hugging FaceのJoão Ganteさんが精力的に取り組んでいるようだ。
Fresh out of a PR: 🤗 transformers supports streaming with `generate()` 🔥
— João Gante (@joao_gante) 2023年3月30日
This means you can now stream the output of your text-to-text, speech-to-text, or image-to-text model -- building responsive demos and applications on top of 🤗 transformers just became easier!
🧵 1/ pic.twitter.com/Gkbz7xqFhJ
ただし今日(2023年4月10日)時点ではまだリリースされておらず、おそらくv4.28.0
以降でリリースされるものと思われる。今回はソースからtransformersをインストールすることで、ストリーミングを試してみる。
transformersのストリーミング機能は、ドキュメントのUtilities for Generationで紹介されており、TextStreamer
あるいはTextIteratorStreamer
をmodel.generate
に渡す形で利用する。TextStreamer
は標準出力に書き出すもので、TextIteratorStreamer
はイテレータとして利用するものだ。
サーバーからストリーミングする
実際にFastAPIを使ったサーバーでTextIteratorStreamer
を使う。まずは次のように、非同期ジェネレータとして関数generate
を定義する。
import asyncio from threading import Thread from typing import AsyncIterator from transformers import ( AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer, ) tokenizer = AutoTokenizer.from_pretrained("rinna/japanese-gpt2-medium", use_fast=False) tokenizer.do_lower_case = True model = AutoModelForCausalLM.from_pretrained("rinna/japanese-gpt2-medium") async def generate(text: str) -> AsyncIterator[str]: inputs = tokenizer(text, add_special_tokens=False, return_tensors="pt") streamer = TextIteratorStreamer(tokenizer) generation_kwargs = dict( **inputs, streamer=streamer, max_new_tokens=100, do_sample=True, top_k=50, top_p=0.95, temperature=0.9, pad_token_id=tokenizer.pad_token_id, bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id, bad_words_ids=[[tokenizer.bos_token_id]], num_return_sequences=1, ) thread = Thread(target=model.generate, kwargs=generation_kwargs) thread.start() for output in streamer: if not output: continue await asyncio.sleep(0) yield output
TextIteratorStreamer
をmodel.generate
に渡している。ドキュメントに倣ってmodel.generate
は別スレッドで実行し、TextIteratorStreamer
をイテレータとして回してyield
している。TextIteratorStreamer
は生成されたテキストを内部である程度バッファリングして、きりのよいところで返してくれる。ややワークアラウンドのようにも見えるが、ループの中でasyncio.sleep
を行って、他の処理をブロッキングしないようにする。
FastAPIのエンドポイントからは次のようにStreamingResponse
を返す。
from fastapi import FastAPI from fastapi.responses import StreamingResponse from pydantic import BaseModel from server.generator import generate app = FastAPI() class GenerateInput(BaseModel): text: str @app.post("/generate") async def generate_post(generate_input: GenerateInput): return StreamingResponse(generate(generate_input.text), media_type="text/plain")
StreamingResponse
は非同期イテレータを受け取るので、これで十分機能する。この実装では単純にtext/plain
のレスポンスがストリームで返るので、クライアントが特別に対応していない場合でも互換性がある。より構造化されたレスポンスを返したい場合は、server-sent eventsとして返す方がよさそうだ。
フロントエンドでストリーミングされたデータを表示する。
フロントエンドでもストリーミングで表示するには少し工夫が要る。次のように非同期ジェネレータを書く。
async function* generate(text: string): AsyncGenerator<string> { const res = await fetch('https://example.com/generate', { method: 'POST', headers: { 'Content-Type': 'application/json', }, body: JSON.stringify({ text }), }) const reader = res.body?.getReader() if (!reader) { return } const decoder = new TextDecoder() while (true) { const { value, done } = await reader.read() if (done) { break } yield decoder.decode(value) } }
普通ならawait res.text()
とするところで、本文の読み込みストリームを取得する。そこから逐次的に値を読み取って、TextDecoder
でデコードしつつyield
していく。
Reactなら、簡略化した例だが、次のように使う。
'use client' import React from 'react' async function* generate(text: string): AsyncGenerator<string> { (省略) } export function Generator() { const [input, setInput] = React.useState('') const onChange: React.ChangeEventHandler<HTMLInputElement> = (event) => { setInput(event.currentTarget.value) } const [output, setOutput] = React.useState('') const onSubmit: React.FormEventHandler<HTMLFormElement> = async (event) => { event.preventDefault() setOutput('') for await (const chunk of generate(input)) { setOutput((prev) => prev + chunk) } } return ( <div> <form onSubmit={onSubmit}> <input onChange={onChange} /> <button type="submit"> Generate </button> </form> <div>{output}</div> </div> ) }
for await...of
で非同期ジェネレータから情報を読み出している。
Time to First Byte (TTFB)
ストリーミングの様子は、Webインスペクタのネットワークタブでも確かめられる。ストリーミングしないと「待機中」のまま長く待たされるが、ストリーミングすると「待機中」はすぐに終わり、「ダウンロード」が続く。つまりTTFBが改善されると言い換えられる。
いかがでしたか
huggingface/transformersのStreamerによって、簡単にテキスト生成結果をストリーミングできることがわかった。まだリリースされていない機能で、リリースまでにAPIが変わるかもしれないが、リリースされるのが楽しみだ。