# File: realtime_voice_bot/python/loopback_stt_worker.py
# Size: 134 lines, 3.8 KiB; language: Python
import base64
import json
import os
import sys
import traceback
from typing import Any
import numpy as np
from faster_whisper import WhisperModel
def _model_attempts(
    requested_device: str, requested_compute: str
) -> list[tuple[str, str]]:
    """Return the ordered ``(device, compute_type)`` pairs to try when loading.

    - device "auto", compute "auto": CUDA float16, CUDA int8_float16,
      then CPU int8, CPU float32.
    - device "auto", explicit compute: that compute type on CUDA, then CPU.
    - explicit device, compute "auto": float16 on "cuda", int8 otherwise.
    - explicit device and compute: exactly that pair.
    """
    if requested_device == "auto":
        if requested_compute == "auto":
            return [
                ("cuda", "float16"),
                ("cuda", "int8_float16"),
                ("cpu", "int8"),
                ("cpu", "float32"),
            ]
        return [("cuda", requested_compute), ("cpu", requested_compute)]
    if requested_compute == "auto":
        compute = "float16" if requested_device == "cuda" else "int8"
    else:
        compute = requested_compute
    return [(requested_device, compute)]


def resolve_model() -> WhisperModel:
    """Load the Whisper model, falling back through device/precision options.

    Reads ``WHISPER_MODEL`` (default ``"large-v3-turbo"``) plus
    ``WHISPER_DEVICE`` and ``WHISPER_COMPUTE_TYPE`` (both default ``"auto"``).
    The first combination that loads is returned, with the chosen values
    recorded on ``_resolved_device`` / ``_resolved_compute_type`` so they can
    be reported later. Each failed attempt is reported on stderr; stdout is
    reserved for the JSON replies emitted by ``write``.

    Raises:
        Exception: whatever the last ``WhisperModel`` attempt raised, if no
            combination could be loaded.
    """
    model_name = os.environ.get("WHISPER_MODEL", "large-v3-turbo")
    attempts = _model_attempts(
        os.environ.get("WHISPER_DEVICE", "auto"),
        os.environ.get("WHISPER_COMPUTE_TYPE", "auto"),
    )
    last_error: Exception | None = None
    for device, compute_type in attempts:
        try:
            model = WhisperModel(model_name, device=device, compute_type=compute_type)
        except Exception as error:  # noqa: BLE001
            # Report the failure instead of swallowing it: otherwise a CUDA
            # problem that silently falls back to CPU leaves no trace.
            print(
                f"loopback_stt_worker: failed to load {model_name!r} "
                f"(device={device}, compute_type={compute_type}): "
                f"{type(error).__name__}: {error}",
                file=sys.stderr,
                flush=True,
            )
            last_error = error
            continue
        setattr(model, "_resolved_device", device)
        setattr(model, "_resolved_compute_type", compute_type)
        return model
    if last_error is None:
        # Unreachable while _model_attempts never returns an empty list; an
        # explicit check is used because ``assert`` is stripped under -O.
        raise RuntimeError("no model load attempts were configured")
    raise last_error
# Load the model once at start-up; every request reuses this instance.
MODEL = resolve_model()
# Whisper language code from WHISPER_LANGUAGE (default "ko"). Following the
# "auto" convention used for WHISPER_DEVICE / WHISPER_COMPUTE_TYPE, "auto" or
# an empty value maps to None, which faster-whisper treats as "detect it".
_language_setting = os.environ.get("WHISPER_LANGUAGE", "ko").strip()
LANGUAGE = None if _language_setting.lower() in ("", "auto") else _language_setting
# Beam width for decoding (WHISPER_BEAM_SIZE, default 1).
BEAM_SIZE = int(os.environ.get("WHISPER_BEAM_SIZE", "1"))
def write(payload: dict[str, Any]) -> None:
    """Emit ``payload`` as one line of JSON on stdout and flush immediately.

    The JSON keeps non-ASCII text verbatim (``ensure_ascii=False``; transcripts
    are Korean by default), so the line is encoded as UTF-8 explicitly rather
    than through stdout's locale encoding. When stdout is a pipe, that locale
    encoding (e.g. cp1252 on Windows) may fail with ``UnicodeEncodeError`` or
    produce non-UTF-8 bytes. Streams without a binary ``buffer`` (e.g. an
    ``io.StringIO`` stand-in) fall back to a plain text write.
    """
    line = json.dumps(payload, ensure_ascii=False) + "\n"
    stream = sys.stdout
    binary = getattr(stream, "buffer", None)
    if binary is None:
        stream.write(line)
        stream.flush()
        return
    # Drain anything already queued on the text wrapper so output order holds.
    stream.flush()
    binary.write(line.encode("utf-8"))
    binary.flush()
def transcribe_pcm16_base64(pcm16_base64: str) -> str:
    """Transcribe base64-encoded signed 16-bit PCM audio and return its text.

    The samples are scaled to float32 in [-1, 1) and passed to the module
    ``MODEL`` with greedy-style settings (``temperature=0.0``, no VAD, no
    timestamps, no conditioning on previous text).
    NOTE(review): channel count and sample rate are not checked here; raw
    arrays are presumably expected as mono at the rate faster-whisper assumes
    (16 kHz) — confirm with the caller.

    Args:
        pcm16_base64: base64 text of little-endian int16 samples.

    Returns:
        The non-empty, stripped segment texts joined by single spaces;
        ``""`` for an empty payload (the model is not invoked).

    Raises:
        ValueError: if the decoded payload has an odd number of bytes, i.e.
            is not a whole number of 16-bit samples.
    """
    audio_bytes = base64.b64decode(pcm16_base64)
    if not audio_bytes:
        return ""
    if len(audio_bytes) % 2:
        raise ValueError(
            f"PCM16 payload must contain an even number of bytes, got {len(audio_bytes)}"
        )
    # "<i2" pins little-endian int16 so decoding does not depend on host byte order.
    samples = np.frombuffer(audio_bytes, dtype="<i2")
    audio = samples.astype(np.float32) / 32768.0
    segments, _info = MODEL.transcribe(
        audio,
        language=LANGUAGE,
        task="transcribe",
        beam_size=BEAM_SIZE,
        condition_on_previous_text=False,
        vad_filter=False,
        without_timestamps=True,
        word_timestamps=False,
        temperature=0.0,
    )
    # Iterating ``segments`` is what drives decoding in faster-whisper.
    texts = [segment.text.strip() for segment in segments if segment.text]
    return " ".join(text for text in texts if text).strip()
def _dispatch(method: Any, params: Any) -> dict[str, Any]:
    """Execute one request ``method`` and return the ``result`` object for it.

    Supported methods:
        ``ping``       -- report the configured model name and the device /
                          compute type recorded on ``MODEL``.
        ``transcribe`` -- transcribe ``params["pcm16_base64"]``.

    Raises:
        RuntimeError: for any other method name.
        KeyError / TypeError: if ``transcribe`` params lack ``pcm16_base64``
            or are not subscriptable by key.
    """
    if method == "ping":
        return {
            "model": os.environ.get("WHISPER_MODEL", "large-v3-turbo"),
            "device": getattr(MODEL, "_resolved_device", "unknown"),
            "compute_type": getattr(MODEL, "_resolved_compute_type", "unknown"),
        }
    if method == "transcribe":
        return {"text": transcribe_pcm16_base64(params["pcm16_base64"])}
    raise RuntimeError(f"unknown method: {method}")


# Serve loop: one JSON request object per stdin line, each answered with one
# JSON line on stdout carrying either "result" or "error" plus the request id.
# As before, this runs at module level as soon as the file is executed.
for raw_line in sys.stdin:
    line = raw_line.strip()
    if not line:
        continue
    # Reset per line so a request that fails before its id is read gets
    # "id": null rather than the previous request's id.
    request_id = None
    try:
        # Parsing lives inside the try: a malformed line, a non-object value or
        # a missing "id" used to escape the loop and kill the worker.
        request = json.loads(line)
        if not isinstance(request, dict):
            raise TypeError(
                f"request must be a JSON object, got {type(request).__name__}"
            )
        request_id = request["id"]
        method = request["method"]
        params = request.get("params", {})
        write({"id": request_id, "result": _dispatch(method, params)})
    except Exception as error:  # noqa: BLE001
        traceback.print_exc(file=sys.stderr)
        write({"id": request_id, "error": f"{type(error).__name__}: {error}"})