Add wyoming-faster-whisper.py
This commit is contained in:
parent
66d5595c68
commit
45aeea9d52
1 changed files with 89 additions and 0 deletions
89
wyoming-faster-whisper.py
Normal file
89
wyoming-faster-whisper.py
Normal file
|
|
@ -0,0 +1,89 @@
|
||||||
|
"""Wyoming protocol STT wrapper for faster-whisper-server."""
|
||||||
|
import argparse, asyncio, io, logging, wave, struct
|
||||||
|
import aiohttp
|
||||||
|
from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStop
|
||||||
|
from wyoming.asr import Transcribe, Transcript
|
||||||
|
from wyoming.event import Event
|
||||||
|
from wyoming.info import Attribution, AsrModel, AsrProgram, Describe, Info
|
||||||
|
from wyoming.server import AsyncEventHandler, AsyncServer
|
||||||
|
|
||||||
|
# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)
|
||||||
|
# Service description sent in reply to Describe events. It is None until
# main() assigns the Info object; WhisperHandler.handle_event reads it.
INFO = None
|
||||||
|
|
||||||
|
class WhisperHandler(AsyncEventHandler):
    """Wyoming event handler that relays buffered audio to faster-whisper-server.

    Incoming audio chunks are converted to a fixed PCM format and buffered.
    When the client signals the end of the audio, the buffer is wrapped in a
    WAV container and POSTed to the server's ``/v1/audio/transcriptions``
    endpoint; the returned text is sent back as a ``Transcript`` event.
    """

    # PCM format shared by the chunk converter and the WAV header so the two
    # can never drift apart.
    _RATE = 16000
    _WIDTH = 2
    _CHANNELS = 1

    # Upper bound (seconds) for one complete transcription request.
    _REQUEST_TIMEOUT = 30

    # Text sent as the "prompt" form field with every request.
    _PROMPT = "Hey Homer, wie geht es dir? Mach das Licht an. Wie spät ist es? Wie wird das Wetter?"

    def __init__(self, reader, writer, cli_args, *a, **kw):
        """Create a handler for one client connection.

        Args:
            reader: Stream reader, passed through to ``AsyncEventHandler``.
            writer: Stream writer, passed through to ``AsyncEventHandler``.
            cli_args: Parsed command-line options; ``whisper_url``, ``model``
                and ``language`` are read when transcribing.
            *a: Extra positional arguments for ``AsyncEventHandler``.
            **kw: Extra keyword arguments for ``AsyncEventHandler``.
        """
        super().__init__(reader, writer, *a, **kw)
        self.cli_args = cli_args
        # bytearray grows in place; repeated ``bytes +=`` would copy the
        # whole buffer for every incoming chunk.
        self.audio_buf = bytearray()
        self.converter = AudioChunkConverter(
            rate=self._RATE, width=self._WIDTH, channels=self._CHANNELS
        )

    async def handle_event(self, event: Event) -> bool:
        """Dispatch one Wyoming event.

        Returns:
            Always ``True``, exactly as before this handler never asks for
            the connection to be closed.
        """
        if Describe.is_type(event.type):
            await self.write_event(INFO.event())
            return True

        if Transcribe.is_type(event.type):
            # A new request starts with an empty buffer.
            self.audio_buf = bytearray()
            return True

        if AudioChunk.is_type(event.type):
            chunk = self.converter.convert(AudioChunk.from_event(event))
            self.audio_buf += chunk.audio
            return True

        if AudioStop.is_type(event.type):
            _LOGGER.info("Transcribing %d bytes of audio", len(self.audio_buf))
            try:
                text = await self._transcribe()
            finally:
                # Drop the audio even if the request was cancelled, so a
                # following utterance on the same connection does not
                # re-send audio that was already handled.
                self.audio_buf = bytearray()
            await self.write_event(Transcript(text=text).event())
            return True

        # Every other event type is ignored.
        return True

    async def _transcribe(self) -> str:
        """Send the buffered audio to the server and return the transcript.

        Returns:
            The "text" field of the JSON response, or ``""`` when no audio
            was buffered, the field is missing/null, or the request fails
            (failures are logged with a traceback).
        """
        if not self.audio_buf:
            # Nothing to transcribe; avoid a pointless HTTP round trip.
            _LOGGER.warning("No audio received, skipping transcription request")
            return ""

        wav_buf = io.BytesIO()
        with wave.open(wav_buf, "wb") as wf:
            wf.setnchannels(self._CHANNELS)
            wf.setsampwidth(self._WIDTH)
            wf.setframerate(self._RATE)
            wf.writeframes(self.audio_buf)
        wav_bytes = wav_buf.getvalue()

        try:
            # Tolerate a trailing slash in --whisper-url.
            base_url = self.cli_args.whisper_url.rstrip("/")
            url = f"{base_url}/v1/audio/transcriptions"

            data = aiohttp.FormData()
            data.add_field("file", wav_bytes, filename="audio.wav", content_type="audio/wav")
            data.add_field("model", self.cli_args.model)
            data.add_field("language", self.cli_args.language)
            data.add_field("prompt", self._PROMPT)

            timeout = aiohttp.ClientTimeout(total=self._REQUEST_TIMEOUT)
            async with aiohttp.ClientSession() as session:
                async with session.post(url, data=data, timeout=timeout) as resp:
                    # Turn HTTP error statuses into an exception so they are
                    # logged as such instead of as a JSON decoding problem.
                    resp.raise_for_status()
                    result = await resp.json()
                    # Guard against an explicit null "text" in the response.
                    return result.get("text") or ""
        except Exception:
            # Best effort: log the failure and answer with an empty transcript.
            _LOGGER.exception("Transcription failed")
            return ""
|
||||||
|
|
||||||
|
async def main():
    """Parse command-line options, publish the service info and run the server.

    Side effects:
        Assigns the module-level ``INFO`` object that ``WhisperHandler``
        sends in reply to ``Describe`` events, then runs a Wyoming TCP
        server that creates one ``WhisperHandler`` per connection.
    """
    global INFO

    parser = argparse.ArgumentParser()
    # New option; the default keeps the previous behavior of binding to
    # all interfaces.
    parser.add_argument("--host", default="0.0.0.0", help="address to bind the Wyoming server to")
    parser.add_argument("--port", type=int, default=10202, help="TCP port to listen on")
    parser.add_argument("--whisper-url", default="http://10.2.1.104:8005",
                        help="base URL of the faster-whisper-server instance")
    parser.add_argument("--model", default="deepdml/faster-whisper-large-v3-turbo-ct2",
                        help="model name sent with each transcription request")
    parser.add_argument("--language", default="de",
                        help="language code sent with each transcription request")
    args = parser.parse_args()

    # Advertise the configured language too, so the service description does
    # not contradict the language actually requested from the server.
    languages = ["de", "en"]
    if args.language and args.language not in languages:
        languages.append(args.language)

    attr = Attribution(name="faster-whisper", url="https://github.com/SYSTRAN/faster-whisper")
    model = AsrModel(
        name="large-v3-turbo",
        description="Large V3 Turbo",
        attribution=attr,
        version="1.0",
        languages=languages,
        installed=True,
    )
    program = AsrProgram(
        name="faster-whisper",
        description="Faster Whisper (large-v3-turbo, German)",
        attribution=attr,
        installed=True,
        version="1.0",
        models=[model],
    )
    INFO = Info(asr=[program])

    server = AsyncServer.from_uri(f"tcp://{args.host}:{args.port}")
    _LOGGER.info("Wyoming faster-whisper on port %d, model=%s, lang=%s", args.port, args.model, args.language)
    # The factory is called with (reader, writer) for each connection.
    await server.run(lambda r, w: WhisperHandler(r, w, args))
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point: enable INFO-level logging, then run the server.
    logging.basicConfig(level=logging.INFO)
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl+C is the normal way to stop the server; exit without a traceback.
        pass
|
||||||
Loading…
Add table
Add a link
Reference in a new issue