168 lines
5.1 KiB
Python
168 lines
5.1 KiB
Python
import base64
|
|
import json
|
|
import os
|
|
import sys
|
|
import tempfile
|
|
import traceback
|
|
import wave
|
|
|
|
|
|
# Advertise UTF-8 stdio to anything that inherits this process's environment.
# NOTE(review): PYTHONIOENCODING is read at interpreter startup, so setting it
# here does not change the encoding of this interpreter's already-open streams.
if "PYTHONIOENCODING" not in os.environ:
    os.environ["PYTHONIOENCODING"] = "utf-8"
|
|
|
|
|
|
def log(message: str) -> None:
    """Write one diagnostic line to stderr and flush it immediately.

    Diagnostics go to stderr because stdout carries the JSON responses
    emitted by ``write_response``.
    """
    stream = sys.stderr
    stream.write(f"{message}\n")
    stream.flush()
|
|
|
|
|
|
def write_response(request_id: int, ok: bool, result=None, error: str | None = None) -> None:
|
|
payload = {
|
|
"id": request_id,
|
|
"ok": ok,
|
|
}
|
|
if ok:
|
|
payload["result"] = result
|
|
else:
|
|
payload["error"] = error or "unknown error"
|
|
|
|
sys.stdout.write(json.dumps(payload, ensure_ascii=False) + "\n")
|
|
sys.stdout.flush()
|
|
|
|
|
|
def resolve_device() -> str:
    """Choose the inference device from LOCAL_STT_DEVICE.

    Any explicit value (stripped and lower-cased) is returned as-is. An empty
    value or ``"auto"`` (the default) probes ctranslate2. The result is
    ``"cuda"`` when it reports at least one CUDA device, and ``"cpu"`` otherwise,
    including when ctranslate2 cannot be imported or the probe raises.
    """
    requested = os.environ.get("LOCAL_STT_DEVICE", "auto").strip().lower()
    if requested not in ("", "auto"):
        return requested

    try:
        import ctranslate2

        has_cuda = ctranslate2.get_cuda_device_count() > 0
    except Exception:
        # Best-effort probe: any failure simply means "no usable GPU".
        has_cuda = False

    return "cuda" if has_cuda else "cpu"
|
|
|
|
|
|
def resolve_compute_type(device: str) -> str:
    """Choose the CTranslate2 compute type for ``device``.

    An explicit LOCAL_STT_COMPUTE_TYPE (stripped and lower-cased) wins.
    Otherwise, an empty value or ``"auto"`` yields ``"int8_float16"`` for
    ``"cuda"`` and ``"int8"`` for any other device.
    """
    override = os.environ.get("LOCAL_STT_COMPUTE_TYPE", "auto").strip().lower()
    if override not in ("", "auto"):
        return override
    return "int8_float16" if device == "cuda" else "int8"
|
|
|
|
|
|
class SttWorker:
|
|
def __init__(self) -> None:
|
|
from faster_whisper import WhisperModel
|
|
|
|
self.model_name = os.environ.get("LOCAL_STT_MODEL", "tiny").strip() or "tiny"
|
|
requested_device = resolve_device()
|
|
requested_compute_type = resolve_compute_type(requested_device)
|
|
self.beam_size = int(os.environ.get("LOCAL_STT_BEAM_SIZE", "1"))
|
|
auto_requested = os.environ.get("LOCAL_STT_DEVICE", "auto").strip().lower() in {"", "auto"}
|
|
|
|
try:
|
|
self.model = WhisperModel(
|
|
self.model_name,
|
|
device=requested_device,
|
|
compute_type=requested_compute_type,
|
|
)
|
|
self.device = requested_device
|
|
self.compute_type = requested_compute_type
|
|
except RuntimeError as exc:
|
|
lowered = str(exc).lower()
|
|
should_fallback = auto_requested and requested_device == "cuda" and any(
|
|
token in lowered for token in ("cublas", "cudnn", "cuda")
|
|
)
|
|
if not should_fallback:
|
|
raise
|
|
|
|
log("CUDA runtime is incomplete; falling back to CPU STT")
|
|
self.model = WhisperModel(
|
|
self.model_name,
|
|
device="cpu",
|
|
compute_type=resolve_compute_type("cpu"),
|
|
)
|
|
self.device = "cpu"
|
|
self.compute_type = resolve_compute_type("cpu")
|
|
|
|
log(
|
|
f"local-stt ready model={self.model_name} device={self.device} compute={self.compute_type} beam={self.beam_size}"
|
|
)
|
|
|
|
def transcribe(self, audio_base64: str, language: str | None) -> str:
|
|
pcm_bytes = base64.b64decode(audio_base64)
|
|
temp_path = ""
|
|
|
|
try:
|
|
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as handle:
|
|
temp_path = handle.name
|
|
|
|
with wave.open(temp_path, "wb") as wav_file:
|
|
wav_file.setnchannels(1)
|
|
wav_file.setsampwidth(2)
|
|
wav_file.setframerate(16000)
|
|
wav_file.writeframes(pcm_bytes)
|
|
|
|
segments, _info = self.model.transcribe(
|
|
temp_path,
|
|
language=language,
|
|
beam_size=self.beam_size,
|
|
best_of=1,
|
|
condition_on_previous_text=False,
|
|
vad_filter=False,
|
|
without_timestamps=True,
|
|
temperature=0.0,
|
|
)
|
|
return " ".join(segment.text.strip() for segment in segments if segment.text.strip()).strip()
|
|
finally:
|
|
if temp_path:
|
|
try:
|
|
os.unlink(temp_path)
|
|
except OSError:
|
|
pass
|
|
|
|
|
|
def _use_utf8_stdio() -> None:
    """Switch the process's standard streams to UTF-8 where possible.

    Setting PYTHONIOENCODING from inside the running interpreter does not affect
    streams that are already open. Without this, responses serialized with
    ``ensure_ascii=False`` can raise UnicodeEncodeError when the pipe's default
    encoding is not UTF-8.
    """
    for stream in (sys.stdin, sys.stdout, sys.stderr):
        reconfigure = getattr(stream, "reconfigure", None)
        if reconfigure is None:
            # Not a TextIOWrapper (e.g. replaced by a harness); leave it alone.
            continue
        try:
            reconfigure(encoding="utf-8")
        except (OSError, ValueError):
            # Best effort: keep the existing encoding rather than fail startup.
            pass


def _handle_request_line(worker: "SttWorker", line: str) -> None:
    """Parse one JSON request line, dispatch it, and emit exactly one outcome.

    Supported methods are ``"ping"`` and ``"transcribe"``. Any failure is
    reported as an error response for the request's id. If the id itself could
    not be determined, there is nothing to correlate a response with, so the
    error is logged to stderr instead.
    """
    # Reset per line so a malformed request never reuses a previous request's id.
    request_id = None
    try:
        request = json.loads(line)
        request_id = int(request["id"])
        method = request["method"]
        params = request.get("params", {})

        if method == "ping":
            write_response(request_id, True, {"ready": True})
            return
        if method != "transcribe":
            raise ValueError(f"unsupported method: {method}")

        text = worker.transcribe(
            audio_base64=str(params.get("audio_base64", "")),
            language=str(params.get("language") or "").strip() or None,
        )
        write_response(request_id, True, {"text": text})
    except Exception as exc:
        error_text = "".join(traceback.format_exception_only(type(exc), exc)).strip()
        if request_id is None:
            log(f"dropping malformed request: {error_text}")
            return
        write_response(request_id, False, error=error_text)


def main() -> int:
    """Run the line-delimited JSON request loop over stdin/stdout.

    Returns:
        1 if the STT worker fails to initialize (details are logged to stderr),
        otherwise 0 once stdin reaches EOF.
    """
    _use_utf8_stdio()

    try:
        worker = SttWorker()
    except Exception as exc:
        log("failed to initialize local STT worker")
        log("run `bun run setup:local-ai` first if dependencies are missing")
        log("".join(traceback.format_exception(exc)))
        return 1

    for raw_line in sys.stdin:
        line = raw_line.strip()
        if line:
            _handle_request_line(worker, line)

    return 0
|
|
|
|
|
|
# Script entry point: propagate main()'s return value as the process exit status.
if __name__ == "__main__":
    sys.exit(main())
|