wyoming-chatterbox/wyoming-faster-whisper.py

89 lines
3.8 KiB
Python

"""Wyoming protocol STT wrapper for faster-whisper-server."""
import argparse, asyncio, io, logging, wave, struct
import aiohttp
from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStop
from wyoming.asr import Transcribe, Transcript
from wyoming.event import Event
from wyoming.info import Attribution, AsrModel, AsrProgram, Describe, Info
from wyoming.server import AsyncEventHandler, AsyncServer
# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)
# Service description sent in reply to Describe events; stays None until
# main() assigns the real Info object at startup.
INFO = None
class WhisperHandler(AsyncEventHandler):
    """Wyoming event handler that relays buffered audio to a whisper HTTP server.

    Incoming ``AudioChunk`` events are converted to 16 kHz / 16-bit / mono PCM
    and appended to an in-memory buffer. On ``AudioStop`` the buffer is wrapped
    in a WAV container and posted to ``<whisper_url>/v1/audio/transcriptions``;
    the ``"text"`` field of the JSON reply is sent back as a ``Transcript``.
    """

    # PCM format shared by the chunk converter and the WAV header so the two
    # can never drift apart.
    SAMPLE_RATE = 16000
    SAMPLE_WIDTH = 2
    CHANNELS = 1

    # Upper bound (seconds) for the whole HTTP request to the whisper server.
    REQUEST_TIMEOUT_S = 30

    # Sent as the "prompt" form field with every request.
    # NOTE(review): presumably primes the model with typical voice commands —
    # confirm against the server's handling of "prompt".
    PROMPT = "Hey Homer, wie geht es dir? Mach das Licht an. Wie spät ist es? Wie wird das Wetter?"

    def __init__(self, reader, writer, cli_args, *a, **kw):
        """Create a handler for one connection.

        Args:
            reader: Stream reader passed through to ``AsyncEventHandler``.
            writer: Stream writer passed through to ``AsyncEventHandler``.
            cli_args: Parsed command-line namespace; ``whisper_url``, ``model``
                and ``language`` are read from it when transcribing.
            *a, **kw: Forwarded unchanged to ``AsyncEventHandler.__init__``.
        """
        super().__init__(reader, writer, *a, **kw)
        self.cli_args = cli_args
        # bytearray gives amortised O(1) appends; += on immutable bytes would
        # re-copy the whole buffer for every incoming chunk.
        self.audio_buf = bytearray()
        self.converter = AudioChunkConverter(
            rate=self.SAMPLE_RATE,
            width=self.SAMPLE_WIDTH,
            channels=self.CHANNELS,
        )

    async def handle_event(self, event: Event) -> bool:
        """Dispatch a single Wyoming event.

        Returns:
            Always ``True``.
            NOTE(review): presumably this tells the server loop to keep the
            connection open — confirm against wyoming's AsyncEventHandler.
        """
        if Describe.is_type(event.type):
            await self.write_event(INFO.event())
        elif Transcribe.is_type(event.type):
            # A new request starts: discard anything left over.
            self.audio_buf = bytearray()
        elif AudioChunk.is_type(event.type):
            chunk = self.converter.convert(AudioChunk.from_event(event))
            self.audio_buf += chunk.audio
        elif AudioStop.is_type(event.type):
            _LOGGER.info("Transcribing %d bytes of audio", len(self.audio_buf))
            text = await self._transcribe()
            # Reset here as well, so a following utterance that is not
            # preceded by a Transcribe event does not include stale audio.
            self.audio_buf = bytearray()
            await self.write_event(Transcript(text=text).event())
        # Unrecognised event types are ignored.
        return True

    def _encode_wav(self) -> bytes:
        """Return the buffered PCM wrapped in an in-memory WAV container."""
        wav_buf = io.BytesIO()
        with wave.open(wav_buf, "wb") as wf:
            wf.setnchannels(self.CHANNELS)
            wf.setsampwidth(self.SAMPLE_WIDTH)
            wf.setframerate(self.SAMPLE_RATE)
            wf.writeframes(self.audio_buf)
        return wav_buf.getvalue()

    async def _transcribe(self) -> str:
        """Send the buffered audio to the whisper server and return its text.

        Returns:
            The transcript, or ``""`` when no audio was buffered or the request
            failed for any reason (failures are logged, never raised).
        """
        if not self.audio_buf:
            # Nothing to recognise; avoid a pointless HTTP round-trip.
            _LOGGER.warning("No audio received; skipping transcription")
            return ""

        wav_bytes = self._encode_wav()
        # Tolerate a trailing slash in --whisper-url.
        url = f"{self.cli_args.whisper_url.rstrip('/')}/v1/audio/transcriptions"
        try:
            data = aiohttp.FormData()
            data.add_field("file", wav_bytes, filename="audio.wav", content_type="audio/wav")
            data.add_field("model", self.cli_args.model)
            data.add_field("language", self.cli_args.language)
            data.add_field("prompt", self.PROMPT)
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    url,
                    data=data,
                    timeout=aiohttp.ClientTimeout(total=self.REQUEST_TIMEOUT_S),
                ) as resp:
                    # Surface HTTP errors in the log instead of silently
                    # reading an error payload that has no "text" field.
                    resp.raise_for_status()
                    result = await resp.json()
                    # Guard against an explicit null "text" in the reply.
                    return result.get("text") or ""
        except Exception:
            # Best-effort boundary: the client still gets an (empty) transcript.
            _LOGGER.exception("Transcription failed")
            return ""
async def main():
    """Parse command-line options, publish the service info, and serve.

    Recognised options (with defaults): ``--port`` (10202), ``--whisper-url``
    (http://10.2.1.104:8005), ``--model``
    (deepdml/faster-whisper-large-v3-turbo-ct2) and ``--language`` (de).

    Assigns the module-level ``INFO`` object that handlers send in reply to
    Describe events, then runs a server bound to ``tcp://0.0.0.0:<port>``
    that creates one ``WhisperHandler`` per connection.
    """
    global INFO

    cli = argparse.ArgumentParser()
    cli.add_argument("--port", type=int, default=10202)
    cli.add_argument("--whisper-url", default="http://10.2.1.104:8005")
    cli.add_argument("--model", default="deepdml/faster-whisper-large-v3-turbo-ct2")
    cli.add_argument("--language", default="de")
    options = cli.parse_args()

    # Build the advertised description bottom-up: attribution -> model -> program.
    credit = Attribution(
        name="faster-whisper",
        url="https://github.com/SYSTRAN/faster-whisper",
    )
    turbo_model = AsrModel(
        name="large-v3-turbo",
        description="Large V3 Turbo",
        attribution=credit,
        version="1.0",
        languages=["de", "en"],
        installed=True,
    )
    program = AsrProgram(
        name="faster-whisper",
        description="Faster Whisper (large-v3-turbo, German)",
        attribution=credit,
        installed=True,
        version="1.0",
        models=[turbo_model],
    )
    INFO = Info(asr=[program])

    def make_handler(reader, writer):
        # Every connection gets its own handler sharing the parsed options.
        return WhisperHandler(reader, writer, options)

    server = AsyncServer.from_uri(f"tcp://0.0.0.0:{options.port}")
    _LOGGER.info(
        "Wyoming faster-whisper on port %d, model=%s, lang=%s",
        options.port,
        options.model,
        options.language,
    )
    await server.run(make_handler)
# Script entry point: configure INFO-level logging and run the server until
# the process is interrupted.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl-C is the expected way to stop the server; exit quietly
        # instead of printing a traceback.
        pass