# File: realtime_voice_bot/python/local_stt_worker.py
# Listing metadata: 238 lines, 7.1 KiB, Python

import base64
import glob
import json
import os
from pathlib import Path
import site
import sys
import sysconfig
import tempfile
import traceback
import wave
# Only affects interpreters started later with this environment; an existing
# value is left alone.
os.environ.setdefault("PYTHONIOENCODING", "utf-8")


def _force_utf8_stdio(streams) -> None:
    """Switch already-open text streams to UTF-8 (best effort).

    ``PYTHONIOENCODING`` is read while the interpreter starts, so setting it
    from inside the running process does not change the encoding of this
    process's own standard streams. Responses are written as JSON with
    ``ensure_ascii=False``, so a non-UTF-8 stdout could fail to encode
    transcribed text.

    Args:
        streams: Iterable of stream objects. Entries that are ``None`` or
            have no ``reconfigure`` method (replaced or missing streams) are
            skipped.
    """
    for stream in streams:
        reconfigure = getattr(stream, "reconfigure", None)
        if reconfigure is None:
            continue
        # Passing ``encoding`` alone would reset the error handler to
        # "strict"; keep whatever the stream already uses.
        errors = getattr(stream, "errors", None) or "strict"
        try:
            reconfigure(encoding="utf-8", errors=errors)
        except (ValueError, OSError):
            # Deliberately best effort: a stream that refuses to be
            # reconfigured keeps its current encoding.
            pass


_force_utf8_stdio((sys.stdin, sys.stdout, sys.stderr))
def log(message: str) -> None:
    """Emit one diagnostic line on stderr, flushing immediately.

    Diagnostics go to stderr; stdout is where this file writes its JSON
    responses.

    Args:
        message: Text to print; a newline is appended.
    """
    print(message, end="\n", file=sys.stderr, flush=True)
def write_response(request_id: int, ok: bool, result=None, error: str | None = None) -> None:
payload = {
"id": request_id,
"ok": ok,
}
if ok:
payload["result"] = result
else:
payload["error"] = error or "unknown error"
sys.stdout.write(json.dumps(payload, ensure_ascii=False) + "\n")
sys.stdout.flush()
def resolve_device() -> str:
    """Pick the inference device name.

    ``LOCAL_STT_DEVICE`` (trimmed, lower-cased) is returned as-is unless it
    is empty or ``"auto"``. In auto mode the result is ``"cuda"`` when
    ctranslate2 reports at least one CUDA device and ``"cpu"`` otherwise,
    including when ctranslate2 cannot be imported or the probe raises.

    Returns:
        The device name.
    """
    requested = os.environ.get("LOCAL_STT_DEVICE", "auto").strip().lower()
    if requested not in ("", "auto"):
        return requested

    try:
        import ctranslate2

        cuda_available = ctranslate2.get_cuda_device_count() > 0
    except Exception:
        # Best-effort probe: any failure simply means "no usable GPU".
        cuda_available = False

    return "cuda" if cuda_available else "cpu"
def configure_windows_cuda_runtime() -> None:
    """Expose CUDA runtime DLL directories to the Windows loader.

    Does nothing unless ``os.name == "nt"``. Candidate directories come from:

    * the ``bin`` folder under ``CUDA_PATH``, ``CUDA_HOME`` and every
      ``CUDA_PATH_V*`` environment variable;
    * ``bin`` folders of toolkits in the default "NVIDIA GPU Computing
      Toolkit" install location, in reverse lexicographic order;
    * folders below ``<site-packages>/nvidia`` that contain cuBLAS 12, cuDNN
      or CUDA runtime DLLs.

    Existing directories are de-duplicated (ignoring case and slash style),
    registered with ``os.add_dll_directory`` and prepended to ``PATH``.
    NOTE(review): both mechanisms are presumably needed because native
    libraries may load their CUDA dependencies through the classic search
    path -- confirm.
    """
    if os.name != "nt":
        return

    candidate_dirs: list[str] = []

    # Toolkit installs advertised through environment variables.
    for key in ("CUDA_PATH", "CUDA_HOME"):
        value = os.environ.get(key)
        if value:
            candidate_dirs.append(os.path.join(value, "bin"))
    candidate_dirs.extend(
        os.path.join(value, "bin")
        for key, value in os.environ.items()
        if key.startswith("CUDA_PATH_V") and value
    )

    # Default install location of the CUDA toolkit.
    candidate_dirs.extend(
        sorted(glob.glob(r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v*\bin"), reverse=True)
    )

    # Site-packages roots that may hold an "nvidia" folder with bundled DLLs.
    site_roots: list[str] = []
    try:
        site_roots.extend(site.getsitepackages())
    except Exception:
        # Best effort: skip this source if the lookup is unavailable.
        pass
    try:
        site_roots.append(site.getusersitepackages())
    except Exception:
        pass
    install_paths = sysconfig.get_paths()
    for key in ("purelib", "platlib"):
        value = install_paths.get(key)
        if value:
            site_roots.append(value)

    # The sources above frequently name the same directory; visit each root
    # once so the recursive DLL search is not repeated.
    for root in dict.fromkeys(site_roots):
        nvidia_root = Path(root) / "nvidia"
        if not nvidia_root.is_dir():
            continue
        for pattern in ("**/cublas64_12.dll", "**/cudnn*.dll", "**/cudart64*.dll"):
            for dll_path in nvidia_root.glob(pattern):
                candidate_dirs.append(str(dll_path.parent))

    # Keep first-seen order. normcase folds case and slashes on Windows so
    # two spellings of one directory count as duplicates.
    unique_by_key: dict[str, str] = {}
    for candidate in candidate_dirs:
        normalized = os.path.normpath(candidate)
        dedupe_key = os.path.normcase(normalized)
        if dedupe_key in unique_by_key or not os.path.isdir(normalized):
            continue
        unique_by_key[dedupe_key] = normalized
    unique_dirs = list(unique_by_key.values())

    for directory in unique_dirs:
        try:
            os.add_dll_directory(directory)
        except (AttributeError, OSError):
            # Best effort: the directory is still added to PATH below.
            pass

    if unique_dirs:
        existing_path = os.environ.get("PATH", "")
        # Avoid a trailing separator (an empty PATH entry) when PATH is unset.
        path_entries = unique_dirs + ([existing_path] if existing_path else [])
        os.environ["PATH"] = os.pathsep.join(path_entries)
        log(f"configured CUDA DLL search paths: {', '.join(unique_dirs)}")
def resolve_compute_type(device: str) -> str:
    """Choose the compute type for ``device``.

    ``LOCAL_STT_COMPUTE_TYPE`` (trimmed, lower-cased) is returned as-is
    unless it is empty or ``"auto"``.

    Args:
        device: Target device name.

    Returns:
        The explicit override, otherwise ``"int8_float16"`` for ``"cuda"``
        and ``"int8"`` for any other device.
    """
    override = os.environ.get("LOCAL_STT_COMPUTE_TYPE", "auto").strip().lower()
    if override in ("", "auto"):
        return "int8_float16" if device == "cuda" else "int8"
    return override
class SttWorker:
def __init__(self) -> None:
configure_windows_cuda_runtime()
from faster_whisper import WhisperModel
self.model_name = os.environ.get("LOCAL_STT_MODEL", "tiny").strip() or "tiny"
requested_device = resolve_device()
requested_compute_type = resolve_compute_type(requested_device)
self.beam_size = int(os.environ.get("LOCAL_STT_BEAM_SIZE", "1"))
auto_requested = os.environ.get("LOCAL_STT_DEVICE", "auto").strip().lower() in {"", "auto"}
try:
self.model = WhisperModel(
self.model_name,
device=requested_device,
compute_type=requested_compute_type,
)
self.device = requested_device
self.compute_type = requested_compute_type
except RuntimeError as exc:
lowered = str(exc).lower()
should_fallback = auto_requested and requested_device == "cuda" and any(
token in lowered for token in ("cublas", "cudnn", "cuda")
)
if not should_fallback:
raise
log("CUDA runtime is incomplete; falling back to CPU STT")
self.model = WhisperModel(
self.model_name,
device="cpu",
compute_type=resolve_compute_type("cpu"),
)
self.device = "cpu"
self.compute_type = resolve_compute_type("cpu")
log(
f"local-stt ready model={self.model_name} device={self.device} compute={self.compute_type} beam={self.beam_size}"
)
def transcribe(self, audio_base64: str, language: str | None) -> str:
pcm_bytes = base64.b64decode(audio_base64)
temp_path = ""
try:
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as handle:
temp_path = handle.name
with wave.open(temp_path, "wb") as wav_file:
wav_file.setnchannels(1)
wav_file.setsampwidth(2)
wav_file.setframerate(16000)
wav_file.writeframes(pcm_bytes)
segments, _info = self.model.transcribe(
temp_path,
language=language,
beam_size=self.beam_size,
best_of=1,
condition_on_previous_text=False,
vad_filter=False,
without_timestamps=True,
temperature=0.0,
)
return " ".join(segment.text.strip() for segment in segments if segment.text.strip()).strip()
finally:
if temp_path:
try:
os.unlink(temp_path)
except OSError:
pass
def main() -> int:
    """Serve requests: one JSON object per stdin line, one response per line.

    Each request carries ``"id"``, ``"method"`` and optional ``"params"``.
    Supported methods are ``"ping"`` (answers ``{"ready": true}``) and
    ``"transcribe"`` (reads ``audio_base64`` and ``language`` from the
    params and answers ``{"text": ...}``). Any failure while handling a line
    is reported as an error response and the loop continues.

    Returns:
        ``1`` if the worker cannot be initialized, ``0`` once stdin ends.
    """
    try:
        worker = SttWorker()
    except Exception as exc:
        log("failed to initialize local STT worker")
        log("run `bun run setup:local-ai` first if dependencies are missing")
        log("".join(traceback.format_exception(exc)))
        return 1

    # Reported when a line fails before its own id could be read.
    unknown_request_id = -1

    for line in sys.stdin:
        line = line.strip()
        if not line:
            continue

        # Reset per line: otherwise a failure before the id is parsed would
        # hit an unbound name on the first request, or be reported under the
        # previous request's id on later ones.
        request_id = unknown_request_id
        try:
            request = json.loads(line)
            request_id = int(request["id"])
            method = request["method"]
            params = request.get("params")
            if params is None:
                # Treat an explicit JSON null like a missing "params" key.
                params = {}

            if method == "ping":
                write_response(request_id, True, {"ready": True})
            elif method == "transcribe":
                text = worker.transcribe(
                    audio_base64=str(params.get("audio_base64", "")),
                    language=str(params.get("language") or "").strip() or None,
                )
                write_response(request_id, True, {"text": text})
            else:
                raise ValueError(f"unsupported method: {method}")
        except Exception as exc:
            # Per-request boundary: report the failure and keep serving.
            error_text = "".join(traceback.format_exception_only(type(exc), exc)).strip()
            write_response(request_id, False, error=error_text)

    return 0
# Script entry point: the process exit status is whatever main() returns.
if __name__ == "__main__":
    sys.exit(main())