Files
realtime_voice_bot/python/local_stt_worker.py

146 lines
4.2 KiB
Python

import base64
import json
import os
import sys
import tempfile
import traceback
import wave
# Force UTF-8 on this process's standard streams.
#
# PYTHONIOENCODING is only consulted when the interpreter starts, so assigning it
# here cannot change the already-created sys.stdin/stdout/stderr; on its own it
# would only affect child processes. write_response() emits
# json.dumps(..., ensure_ascii=False), so a non-ASCII transcript written to a
# stdout that uses a legacy locale encoding (e.g. cp1252 on Windows) would raise
# UnicodeEncodeError. When the parent did not provide the variable, apply the
# requested encoding to the live streams too, keeping each stream's error handler
# (reconfigure() would otherwise reset it to "strict").
if "PYTHONIOENCODING" not in os.environ:
    os.environ["PYTHONIOENCODING"] = "utf-8"
    for _stream in (sys.stdin, sys.stdout, sys.stderr):
        _reconfigure = getattr(_stream, "reconfigure", None)
        if _reconfigure is None:
            # e.g. the stream is None (no console) or not an io.TextIOWrapper.
            continue
        try:
            _reconfigure(encoding="utf-8", errors=_stream.errors)
        except (AttributeError, OSError, ValueError):
            # Best effort: leave a stream that cannot be reconfigured as it is.
            pass
    del _stream, _reconfigure
def log(message: str) -> None:
    """Write *message* as a single line to stderr and flush it right away.

    stdout carries the JSON response lines produced by ``write_response``, so
    diagnostics go to stderr to keep them out of that stream.
    """
    target = sys.stderr
    print(message, flush=True, file=target)
def write_response(request_id: int, ok: bool, result=None, error: str | None = None) -> None:
payload = {
"id": request_id,
"ok": ok,
}
if ok:
payload["result"] = result
else:
payload["error"] = error or "unknown error"
sys.stdout.write(json.dumps(payload, ensure_ascii=False) + "\n")
sys.stdout.flush()
def resolve_device() -> str:
    """Choose the device string passed to ``WhisperModel``.

    ``LOCAL_STT_DEVICE`` (trimmed and lower-cased) is returned as-is unless it is
    empty or ``"auto"``. In auto mode the result is ``"cuda"`` when ctranslate2
    reports at least one CUDA device, and ``"cpu"`` otherwise. The fallback to
    ``"cpu"`` also applies when ctranslate2 cannot be imported or the query fails.
    """
    requested = os.environ.get("LOCAL_STT_DEVICE", "auto").strip().lower()
    if requested not in ("", "auto"):
        return requested
    try:
        import ctranslate2

        cuda_available = ctranslate2.get_cuda_device_count() > 0
    except Exception:
        # Best-effort probe: any failure simply means "no usable GPU".
        cuda_available = False
    return "cuda" if cuda_available else "cpu"
def resolve_compute_type(device: str) -> str:
    """Choose the ``compute_type`` passed to ``WhisperModel``.

    An explicit ``LOCAL_STT_COMPUTE_TYPE`` (trimmed and lower-cased) wins unless
    it is empty or ``"auto"``. In that case the default is ``"int8_float16"`` for
    ``device == "cuda"`` and ``"int8"`` for every other device.
    """
    requested = os.environ.get("LOCAL_STT_COMPUTE_TYPE", "auto").strip().lower()
    if requested in ("", "auto"):
        return "int8_float16" if device == "cuda" else "int8"
    return requested
class SttWorker:
def __init__(self) -> None:
from faster_whisper import WhisperModel
self.model_name = os.environ.get("LOCAL_STT_MODEL", "tiny").strip() or "tiny"
self.device = resolve_device()
self.compute_type = resolve_compute_type(self.device)
self.beam_size = int(os.environ.get("LOCAL_STT_BEAM_SIZE", "1"))
self.model = WhisperModel(
self.model_name,
device=self.device,
compute_type=self.compute_type,
)
log(
f"local-stt ready model={self.model_name} device={self.device} compute={self.compute_type} beam={self.beam_size}"
)
def transcribe(self, audio_base64: str, language: str | None) -> str:
pcm_bytes = base64.b64decode(audio_base64)
temp_path = ""
try:
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as handle:
temp_path = handle.name
with wave.open(temp_path, "wb") as wav_file:
wav_file.setnchannels(1)
wav_file.setsampwidth(2)
wav_file.setframerate(16000)
wav_file.writeframes(pcm_bytes)
segments, _info = self.model.transcribe(
temp_path,
language=language,
beam_size=self.beam_size,
best_of=1,
condition_on_previous_text=False,
vad_filter=False,
without_timestamps=True,
temperature=0.0,
)
return " ".join(segment.text.strip() for segment in segments if segment.text.strip()).strip()
finally:
if temp_path:
try:
os.unlink(temp_path)
except OSError:
pass
def main() -> int:
    """Run the JSON-lines request loop until stdin is closed.

    Each non-blank stdin line is expected to be a JSON object of the form
    ``{"id": <int>, "method": "ping" | "transcribe", "params": {...}}``.
    ``transcribe`` reads ``params["audio_base64"]`` and an optional
    ``params["language"]``. Every request whose ``id`` can be parsed gets exactly
    one response line on stdout via ``write_response``. Requests without a
    usable ``id`` are reported on stderr instead.

    Returns:
        1 if the speech-to-text worker could not be initialised, otherwise 0
        once stdin reaches EOF.
    """
    try:
        worker = SttWorker()
    except Exception as exc:
        log("failed to initialize local STT worker")
        log("run `bun run setup:local-ai` first if dependencies are missing")
        log("".join(traceback.format_exception(exc)))
        return 1
    for line in sys.stdin:
        line = line.strip()
        if not line:
            continue
        # Reset for every request. Otherwise a failure before the id is parsed
        # would hit an unbound name on the first line (crashing the worker) or be
        # reported under the previous request's id.
        request_id: int | None = None
        try:
            request = json.loads(line)
            request_id = int(request["id"])
            method = request["method"]
            # A missing, null or empty "params" is treated as no parameters.
            params = request.get("params") or {}
            if method == "ping":
                write_response(request_id, True, {"ready": True})
                continue
            if method != "transcribe":
                raise ValueError(f"unsupported method: {method}")
            text = worker.transcribe(
                audio_base64=str(params.get("audio_base64", "")),
                language=str(params.get("language") or "").strip() or None,
            )
            write_response(request_id, True, {"text": text})
        except Exception as exc:
            error_text = "".join(traceback.format_exception_only(type(exc), exc)).strip()
            if request_id is None:
                # No id to correlate a response with: report on stderr rather than
                # emitting a protocol line the caller cannot match.
                log(f"ignoring malformed request: {error_text}")
            else:
                write_response(request_id, False, error=error_text)
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())