Add full STT LLM TTS test mode
This commit is contained in:
110
docker/melotts/melo_tts_worker.py
Normal file
110
docker/melotts/melo_tts_worker.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from melo.api import TTS
|
||||
|
||||
|
||||
LANGUAGE = os.getenv("TTS_LANGUAGE", "KR")
|
||||
SPEAKER = os.getenv("TTS_SPEAKER", "KR")
|
||||
DEVICE = os.getenv("TTS_DEVICE", "cpu")
|
||||
SPEED = float(os.getenv("TTS_SPEED", "1.18"))
|
||||
SDP_RATIO = float(os.getenv("TTS_SDP_RATIO", "0.22"))
|
||||
NOISE_SCALE = float(os.getenv("TTS_NOISE_SCALE", "0.55"))
|
||||
NOISE_SCALE_W = float(os.getenv("TTS_NOISE_SCALE_W", "0.75"))
|
||||
|
||||
_MODEL = None
|
||||
_SPEAKER_ID = None
|
||||
|
||||
|
||||
def load_model():
|
||||
global _MODEL
|
||||
global _SPEAKER_ID
|
||||
|
||||
if _MODEL is not None and _SPEAKER_ID is not None:
|
||||
return _MODEL, _SPEAKER_ID
|
||||
|
||||
model = TTS(language=LANGUAGE, device=DEVICE)
|
||||
speaker_ids = model.hps.data.spk2id
|
||||
|
||||
if SPEAKER not in speaker_ids:
|
||||
supported = ", ".join(sorted(speaker_ids.keys()))
|
||||
raise RuntimeError(f"지원하지 않는 speaker 입니다: {SPEAKER}. 사용 가능: {supported}")
|
||||
|
||||
_MODEL = model
|
||||
_SPEAKER_ID = speaker_ids[SPEAKER]
|
||||
return _MODEL, _SPEAKER_ID
|
||||
|
||||
|
||||
def handle_ping():
|
||||
model, speaker_id = load_model()
|
||||
return {
|
||||
"language": LANGUAGE,
|
||||
"speaker": SPEAKER,
|
||||
"speaker_id": speaker_id,
|
||||
"device": DEVICE,
|
||||
"speed": SPEED,
|
||||
"sdp_ratio": SDP_RATIO,
|
||||
"noise_scale": NOISE_SCALE,
|
||||
"noise_scale_w": NOISE_SCALE_W,
|
||||
"speaker_count": len(model.hps.data.spk2id),
|
||||
}
|
||||
|
||||
|
||||
def handle_synthesize(params):
|
||||
text = str(params["text"]).strip()
|
||||
output_path = Path(str(params["output_path"]))
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
model, speaker_id = load_model()
|
||||
model.tts_to_file(
|
||||
text,
|
||||
speaker_id,
|
||||
str(output_path),
|
||||
speed=SPEED,
|
||||
sdp_ratio=SDP_RATIO,
|
||||
noise_scale=NOISE_SCALE,
|
||||
noise_scale_w=NOISE_SCALE_W,
|
||||
)
|
||||
|
||||
return {
|
||||
"output_path": str(output_path),
|
||||
"text_length": len(text),
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
for raw_line in sys.stdin:
|
||||
line = raw_line.strip()
|
||||
if not line:
|
||||
continue
|
||||
|
||||
try:
|
||||
payload = json.loads(line)
|
||||
request_id = str(payload["id"])
|
||||
method = payload["method"]
|
||||
params = payload.get("params", {})
|
||||
|
||||
if method == "ping":
|
||||
result = handle_ping()
|
||||
elif method == "synthesize":
|
||||
result = handle_synthesize(params)
|
||||
else:
|
||||
raise RuntimeError(f"알 수 없는 method 입니다: {method}")
|
||||
|
||||
sys.stdout.write(json.dumps({"id": request_id, "result": result}, ensure_ascii=False) + "\n")
|
||||
sys.stdout.flush()
|
||||
except Exception as error:
|
||||
request_id = "unknown"
|
||||
try:
|
||||
request_id = str(payload.get("id", "unknown"))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
sys.stdout.write(json.dumps({"id": request_id, "error": str(error)}, ensure_ascii=False) + "\n")
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user