134 lines
3.8 KiB
Python
134 lines
3.8 KiB
Python
import base64
|
|
import json
|
|
import os
|
|
import sys
|
|
import traceback
|
|
from typing import Any
|
|
|
|
import numpy as np
|
|
from faster_whisper import WhisperModel
|
|
|
|
|
|
def resolve_model() -> WhisperModel:
|
|
model_name = os.environ.get("WHISPER_MODEL", "large-v3-turbo")
|
|
requested_device = os.environ.get("WHISPER_DEVICE", "auto")
|
|
requested_compute = os.environ.get("WHISPER_COMPUTE_TYPE", "auto")
|
|
|
|
attempts: list[tuple[str, str]] = []
|
|
if requested_device == "auto":
|
|
if requested_compute == "auto":
|
|
attempts.extend(
|
|
[
|
|
("cuda", "float16"),
|
|
("cuda", "int8_float16"),
|
|
("cpu", "int8"),
|
|
("cpu", "float32"),
|
|
]
|
|
)
|
|
else:
|
|
attempts.extend(
|
|
[
|
|
("cuda", requested_compute),
|
|
("cpu", requested_compute),
|
|
]
|
|
)
|
|
else:
|
|
if requested_compute == "auto":
|
|
compute = "float16" if requested_device == "cuda" else "int8"
|
|
else:
|
|
compute = requested_compute
|
|
attempts.append((requested_device, compute))
|
|
|
|
last_error: Exception | None = None
|
|
for device, compute_type in attempts:
|
|
try:
|
|
model = WhisperModel(model_name, device=device, compute_type=compute_type)
|
|
setattr(model, "_resolved_device", device)
|
|
setattr(model, "_resolved_compute_type", compute_type)
|
|
return model
|
|
except Exception as error: # noqa: BLE001
|
|
last_error = error
|
|
|
|
assert last_error is not None
|
|
raise last_error
|
|
|
|
|
|
# Cheap configuration is parsed and validated before the model is loaded, so
# an invalid value fails immediately instead of after a (possibly slow) load.

# WHISPER_LANGUAGE: language code passed to transcribe (default "ko").
# An empty value or "auto" becomes None, which asks faster-whisper to detect
# the spoken language itself.
_language_setting = os.environ.get("WHISPER_LANGUAGE", "ko").strip()
LANGUAGE = None if _language_setting.lower() in ("", "auto") else _language_setting

# WHISPER_BEAM_SIZE: beam width used when decoding (default 1).
BEAM_SIZE = int(os.environ.get("WHISPER_BEAM_SIZE", "1"))
if BEAM_SIZE < 1:
    raise ValueError(f"WHISPER_BEAM_SIZE must be >= 1, got {BEAM_SIZE}")

MODEL = resolve_model()
|
|
|
|
|
|
def write(payload: dict[str, Any]) -> None:
    """Serialize *payload* as a single JSON line on stdout, then flush.

    Non-ASCII characters are emitted verbatim rather than as ``\\u`` escapes
    (``ensure_ascii=False``). The explicit flush pushes each line out
    immediately instead of leaving it in the stdout buffer.
    """
    encoded = json.dumps(payload, ensure_ascii=False)
    out = sys.stdout
    out.write(f"{encoded}\n")
    out.flush()
|
|
|
|
|
|
def transcribe_pcm16_base64(pcm16_base64: str) -> str:
|
|
audio_bytes = base64.b64decode(pcm16_base64)
|
|
audio = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32768.0
|
|
|
|
segments, _info = MODEL.transcribe(
|
|
audio,
|
|
language=LANGUAGE,
|
|
task="transcribe",
|
|
beam_size=BEAM_SIZE,
|
|
condition_on_previous_text=False,
|
|
vad_filter=False,
|
|
without_timestamps=True,
|
|
word_timestamps=False,
|
|
temperature=0.0,
|
|
)
|
|
|
|
text_parts: list[str] = []
|
|
for segment in segments:
|
|
if segment.text:
|
|
text_parts.append(segment.text.strip())
|
|
return " ".join(part for part in text_parts if part).strip()
|
|
|
|
|
|
def _dispatch(method: Any, params: Any) -> dict[str, Any]:
    """Execute one request and return the value for the reply's ``"result"`` key.

    Supported methods are ``"ping"`` (reports the model name and the resolved
    device/compute type) and ``"transcribe"`` (reads ``params["pcm16_base64"]``).

    Raises:
        RuntimeError: for an unrecognized *method*.
        Exception: anything raised by the handler, e.g. ``KeyError`` when a
            ``transcribe`` request has no ``pcm16_base64``.
    """
    if method == "ping":
        return {
            "model": os.environ.get("WHISPER_MODEL", "large-v3-turbo"),
            "device": getattr(MODEL, "_resolved_device", "unknown"),
            "compute_type": getattr(MODEL, "_resolved_compute_type", "unknown"),
        }
    if method == "transcribe":
        return {"text": transcribe_pcm16_base64(params["pcm16_base64"])}
    raise RuntimeError(f"unknown method: {method}")


# Replies are serialized with ensure_ascii=False, so stdout must be able to
# encode any character (Korean text by default). A piped stdout can fall back
# to a legacy code page on some platforms and raise UnicodeEncodeError; force
# UTF-8. NOTE(review): assumes the reading process decodes stdout as UTF-8.
if hasattr(sys.stdout, "reconfigure"):
    sys.stdout.reconfigure(encoding="utf-8")

# JSON-lines loop: one request object per stdin line, one reply line per request.
for raw_line in sys.stdin:
    line = raw_line.strip()
    if not line:
        continue

    # Stays None until the id has been read, so even a request that fails to
    # parse gets an error reply instead of crashing the worker.
    request_id = None
    try:
        request = json.loads(line)
        if not isinstance(request, dict):
            raise TypeError(
                f"request must be a JSON object, got {type(request).__name__}"
            )
        request_id = request["id"]
        params = request.get("params") or {}
        write({"id": request_id, "result": _dispatch(request["method"], params)})
    except Exception as error:  # noqa: BLE001
        traceback.print_exc(file=sys.stderr)
        write(
            {
                "id": request_id,
                "error": f"{type(error).__name__}: {error}",
            }
        )
|