Switch local TTS to Kokoro ONNX
This commit is contained in:
@@ -1,9 +1,12 @@
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import traceback
|
||||
import wave
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
# Provide "utf-8" as the default PYTHONIOENCODING; setdefault leaves any value
# that is already present in the environment untouched.
# NOTE(review): the interpreter reads this variable at startup, so setting it
# here presumably only influences child processes - confirm the intent.
os.environ.setdefault("PYTHONIOENCODING", "utf-8")
|
||||
@@ -27,53 +30,61 @@ def write_response(request_id: int, ok: bool, result=None, error: str | None = N
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def normalize_lang(raw: str) -> str:
    """Return the canonical lowercase language code for *raw*.

    Surrounding whitespace is ignored and matching is case-insensitive.
    The aliases ``kr`` and ``ko-kr``, as well as a blank value, resolve to
    ``"ko"``; any other value is returned trimmed and lowercased.
    """
    code = raw.strip().lower()
    # Blank input and the Korean aliases all collapse to the same code.
    if not code or code in ("kr", "ko-kr"):
        return "ko"
    return code
|
||||
|
||||
|
||||
def normalize_voice(raw: str) -> str:
    """Return the voice name to use for *raw*.

    A blank value, or ``KR`` / ``KO`` in any letter case, resolves to the
    default voice ``"af_heart"``. Any other value is returned with
    surrounding whitespace removed and its letter case preserved.
    """
    name = raw.strip()
    # NOTE(review): "KR"/"KO" look like speaker keys from an earlier backend
    # rather than real voice names - confirm before removing the aliases.
    use_default = name == "" or name.upper() in ("KR", "KO")
    return "af_heart" if use_default else name
|
||||
|
||||
|
||||
class TtsWorker:
    """Local text-to-speech worker backed by a Kokoro ONNX model.

    Configuration is read from environment variables at construction time:

    * ``LOCAL_TTS_MODEL_PATH`` (required): model file passed to ``Kokoro``.
    * ``LOCAL_TTS_VOICES_PATH`` (required): voices file passed to ``Kokoro``.
    * ``LOCAL_TTS_LANGUAGE``: normalized with ``normalize_lang``; default
      ``"ko"``.
    * ``LOCAL_TTS_SPEAKER``: normalized with ``normalize_voice``; default
      ``"af_heart"``.
    * ``LOCAL_TTS_SPEED``: float speed factor; default ``1.12``.
    """

    def __init__(self) -> None:
        """Load the G2P front end and the Kokoro model.

        Raises:
            KeyError: If ``LOCAL_TTS_MODEL_PATH`` or ``LOCAL_TTS_VOICES_PATH``
                is not set.
            ValueError: If ``LOCAL_TTS_SPEED`` is set to a non-numeric value.
        """
        # Function-scope imports are kept from the original; presumably so
        # the module can be imported without these packages installed.
        from kokoro_onnx import Kokoro
        from misaki import ko

        self.model_path = os.environ["LOCAL_TTS_MODEL_PATH"]
        self.voices_path = os.environ["LOCAL_TTS_VOICES_PATH"]
        self.language = normalize_lang(os.environ.get("LOCAL_TTS_LANGUAGE", "ko"))
        self.voice = normalize_voice(os.environ.get("LOCAL_TTS_SPEAKER", "af_heart"))

        # A blank value falls back to the default instead of failing in
        # float(""), matching how the other settings treat empty strings.
        raw_speed = os.environ.get("LOCAL_TTS_SPEED", "").strip()
        self.speed = float(raw_speed) if raw_speed else 1.12

        self.g2p = ko.KOG2P()
        self.model = Kokoro(self.model_path, self.voices_path)

        log(
            f"local-tts ready model={os.path.basename(self.model_path)} voice={self.voice} language={self.language} speed={self.speed}"
        )

    def synthesize(self, text: str) -> bytes:
        """Convert *text* to speech and return it as WAV file bytes.

        The text is first converted to phonemes by the Korean G2P front end,
        then rendered by the Kokoro model with the configured voice and
        speed, and finally encoded by ``build_wav_bytes``.
        """
        # The G2P callable returns a pair; only the phoneme string is used.
        phonemes, _tokens = self.g2p(text)
        # NOTE(review): lang stays "en-us" although the phonemes come from the
        # Korean G2P; presumably it is ignored when is_phonemes=True - confirm
        # against the kokoro_onnx API.
        samples, sample_rate = self.model.create(
            phonemes,
            voice=self.voice,
            speed=self.speed,
            lang="en-us",
            is_phonemes=True,
        )
        return build_wav_bytes(samples, sample_rate)
|
||||
def build_wav_bytes(samples: np.ndarray, sample_rate: int) -> bytes:
    """Encode floating-point audio samples as a mono 16-bit PCM WAV file.

    Args:
        samples: Audio samples nominally in ``[-1.0, 1.0]``. Any array-like
            is accepted; out-of-range values are clipped and NaN values are
            written as silence.
        sample_rate: Frame rate written to the WAV header.

    Returns:
        The complete WAV file contents (header and frames) as bytes.
    """
    audio = np.asarray(samples)
    # Clip first so +/-inf saturate at full scale, then replace NaN with 0:
    # casting NaN to int16 is undefined and would yield arbitrary samples.
    bounded = np.nan_to_num(np.clip(audio, -1.0, 1.0), nan=0.0)
    pcm = (bounded * 32767.0).astype(np.int16)

    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wav_file:
        wav_file.setnchannels(1)  # mono
        wav_file.setsampwidth(2)  # 2 bytes per sample = 16-bit PCM
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(pcm.tobytes())

    return buffer.getvalue()
|
||||
|
||||
|
||||
def main() -> int:
|
||||
|
||||
Reference in New Issue
Block a user