diff --git a/wyoming-faster-whisper.py b/wyoming-faster-whisper.py new file mode 100644 index 0000000..c975f5f --- /dev/null +++ b/wyoming-faster-whisper.py @@ -0,0 +1,89 @@ +"""Wyoming protocol STT wrapper for faster-whisper-server.""" +import argparse, asyncio, io, logging, wave, struct +import aiohttp +from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStop +from wyoming.asr import Transcribe, Transcript +from wyoming.event import Event +from wyoming.info import Attribution, AsrModel, AsrProgram, Describe, Info +from wyoming.server import AsyncEventHandler, AsyncServer + +_LOGGER = logging.getLogger(__name__) +INFO = None + +class WhisperHandler(AsyncEventHandler): + def __init__(self, reader, writer, cli_args, *a, **kw): + super().__init__(reader, writer, *a, **kw) + self.cli_args = cli_args + self.audio_buf = bytes() + self.converter = AudioChunkConverter(rate=16000, width=2, channels=1) + + async def handle_event(self, event: Event) -> bool: + if Describe.is_type(event.type): + await self.write_event(INFO.event()) + return True + if Transcribe.is_type(event.type): + self.audio_buf = bytes() + return True + if AudioChunk.is_type(event.type): + chunk = AudioChunk.from_event(event) + chunk = self.converter.convert(chunk) + self.audio_buf += chunk.audio + return True + if AudioStop.is_type(event.type): + _LOGGER.info("Transcribing %d bytes of audio", len(self.audio_buf)) + text = await self._transcribe() + await self.write_event(Transcript(text=text).event()) + return True + return True + + async def _transcribe(self) -> str: + wav_buf = io.BytesIO() + with wave.open(wav_buf, "wb") as wf: + wf.setnchannels(1) + wf.setsampwidth(2) + wf.setframerate(16000) + wf.writeframes(self.audio_buf) + wav_bytes = wav_buf.getvalue() + try: + data = aiohttp.FormData() + data.add_field("file", wav_bytes, filename="audio.wav", content_type="audio/wav") + data.add_field("model", self.cli_args.model) + data.add_field("language", self.cli_args.language) + data.add_field("prompt", "Hey Homer, wie geht es dir? Mach das Licht an. Wie spät ist es? Wie wird das Wetter?") + async with aiohttp.ClientSession() as session: + async with session.post( + f"{self.cli_args.whisper_url}/v1/audio/transcriptions", + data=data, timeout=aiohttp.ClientTimeout(total=30) + ) as resp: + result = await resp.json() + return result.get("text", "") + except Exception: + _LOGGER.exception("Transcription failed") + return "" + +async def main(): + global INFO + parser = argparse.ArgumentParser() + parser.add_argument("--port", type=int, default=10202) + parser.add_argument("--whisper-url", default="http://10.2.1.104:8005") + parser.add_argument("--model", default="deepdml/faster-whisper-large-v3-turbo-ct2") + parser.add_argument("--language", default="de") + args = parser.parse_args() + + attr = Attribution(name="faster-whisper", url="https://github.com/SYSTRAN/faster-whisper") + INFO = Info( + asr=[AsrProgram( + name="faster-whisper", description="Faster Whisper (large-v3-turbo, German)", + attribution=attr, installed=True, version="1.0", + models=[AsrModel(name="large-v3-turbo", description="Large V3 Turbo", + attribution=attr, version="1.0", + languages=["de","en"], installed=True)], + )] + ) + server = AsyncServer.from_uri(f"tcp://0.0.0.0:{args.port}") + _LOGGER.info("Wyoming faster-whisper on port %d, model=%s, lang=%s", args.port, args.model, args.language) + await server.run(lambda r,w: WhisperHandler(r, w, args)) + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + asyncio.run(main())