Add Discord-native hybrid front-end for Jarvis (bot + bridge)
Some checks failed
Release / semantic-release (push) Successful in 59s
tests / Unit tests (Linux, Python 3.11) (push) Successful in 13m45s
Release / build-linux (push) Failing after 7m47s
Release / build-windows (push) Has been cancelled
Release / build-macos (arm64, macos-latest) (push) Has been cancelled
Release / build-macos (x64, macos-15-intel) (push) Has been cancelled
Release / release-main (push) Has been cancelled
Release / release-develop (push) Has been cancelled
Some checks failed
Release / semantic-release (push) Successful in 59s
tests / Unit tests (Linux, Python 3.11) (push) Successful in 13m45s
Release / build-linux (push) Failing after 7m47s
Release / build-windows (push) Has been cancelled
Release / build-macos (arm64, macos-latest) (push) Has been cancelled
Release / build-macos (x64, macos-15-intel) (push) Has been cancelled
Release / release-main (push) Has been cancelled
Release / release-develop (push) Has been cancelled
Transform isair/jarvis into a Discord-controlled voice assistant running on the Ubuntu VNC desktop, keeping the mature ~39k-line Python brain intact. - bot/ (Node + bun, discord.js): /자비스 slash commands (ephemeral), voice channel join + voice receive/playback, pluggable VNC screen broadcast (selfbot live / noVNC / screenshot) - bridge/ (Python, Flask): wraps jarvis STT + run_reply_engine + Piper TTS behind a thin localhost HTTP API - .env.example, scripts/ (start_bridge/start_bot/dev), README rewrite, docs/language-comparison.md and docs/vnc-xfce-setup.md Language decision: hybrid (Python brain + Node/bun Discord layer) because Discord blocks bot video; native screen broadcast only works via a Node selfbot library.
This commit is contained in:
634
tests/test_piper_tts.py
Normal file
634
tests/test_piper_tts.py
Normal file
@@ -0,0 +1,634 @@
|
||||
"""Tests for Piper TTS implementation."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
import threading
|
||||
import time
|
||||
|
||||
|
||||
class TestPiperTTSInterface:
|
||||
"""Tests for PiperTTS interface compliance."""
|
||||
|
||||
def test_has_required_methods(self):
|
||||
"""PiperTTS should have the same interface as TextToSpeech."""
|
||||
from src.jarvis.output.tts import PiperTTS
|
||||
|
||||
# Create instance with TTS disabled (no model needed)
|
||||
tts = PiperTTS(enabled=False)
|
||||
|
||||
# Check required methods exist
|
||||
assert hasattr(tts, "start")
|
||||
assert callable(tts.start)
|
||||
|
||||
assert hasattr(tts, "stop")
|
||||
assert callable(tts.stop)
|
||||
|
||||
assert hasattr(tts, "speak")
|
||||
assert callable(tts.speak)
|
||||
|
||||
assert hasattr(tts, "interrupt")
|
||||
assert callable(tts.interrupt)
|
||||
|
||||
assert hasattr(tts, "is_speaking")
|
||||
assert callable(tts.is_speaking)
|
||||
|
||||
assert hasattr(tts, "get_last_spoken_text")
|
||||
assert callable(tts.get_last_spoken_text)
|
||||
|
||||
def test_initialization_disabled(self):
|
||||
"""PiperTTS should handle disabled state gracefully."""
|
||||
from src.jarvis.output.tts import PiperTTS
|
||||
|
||||
tts = PiperTTS(enabled=False)
|
||||
|
||||
# Should not crash when disabled
|
||||
tts.start()
|
||||
tts.speak("test text")
|
||||
assert tts.is_speaking() is False
|
||||
tts.interrupt()
|
||||
tts.stop()
|
||||
|
||||
def test_initialization_with_all_parameters(self):
|
||||
"""PiperTTS should accept all configuration parameters."""
|
||||
from src.jarvis.output.tts import PiperTTS
|
||||
|
||||
tts = PiperTTS(
|
||||
enabled=True,
|
||||
voice="test-voice", # Interface compatibility
|
||||
rate=200, # Interface compatibility
|
||||
model_path="/path/to/model.onnx",
|
||||
speaker=0,
|
||||
length_scale=1.2,
|
||||
noise_scale=0.5,
|
||||
noise_w=0.7,
|
||||
sentence_silence=0.3,
|
||||
)
|
||||
|
||||
# Verify parameters are stored
|
||||
assert tts.enabled is True
|
||||
assert tts.voice == "test-voice"
|
||||
assert tts.rate == 200
|
||||
assert tts.model_path == "/path/to/model.onnx"
|
||||
assert tts.speaker == 0
|
||||
assert tts.length_scale == 1.2
|
||||
assert tts.noise_scale == 0.5
|
||||
assert tts.noise_w == 0.7
|
||||
assert tts.sentence_silence == 0.3
|
||||
|
||||
|
||||
class TestPiperTTSErrorHandling:
|
||||
"""Tests for PiperTTS error handling."""
|
||||
|
||||
def test_missing_model_with_failed_download(self, tmp_path):
|
||||
"""PiperTTS should handle failed download gracefully."""
|
||||
from src.jarvis.output.tts import PiperTTS
|
||||
from unittest.mock import patch
|
||||
|
||||
# Use a non-existent custom path to force download attempt
|
||||
custom_model = str(tmp_path / "nonexistent-voice.onnx")
|
||||
tts = PiperTTS(enabled=True, model_path=custom_model)
|
||||
|
||||
# Mock the download to fail
|
||||
with patch("src.jarvis.output.tts._download_piper_voice", return_value=None):
|
||||
result = tts._ensure_initialized()
|
||||
assert result is False
|
||||
assert tts._init_error is not None
|
||||
assert "download" in tts._init_error.lower() or "failed" in tts._init_error.lower()
|
||||
|
||||
def test_nonexistent_model_file_with_failed_download(self):
|
||||
"""PiperTTS should handle nonexistent model file gracefully when download fails."""
|
||||
from src.jarvis.output.tts import PiperTTS
|
||||
from unittest.mock import patch
|
||||
|
||||
tts = PiperTTS(enabled=True, model_path="/nonexistent/path/model.onnx")
|
||||
|
||||
# Mock the download to fail
|
||||
with patch("src.jarvis.output.tts._download_piper_voice", return_value=None):
|
||||
result = tts._ensure_initialized()
|
||||
assert result is False
|
||||
assert tts._init_error is not None
|
||||
|
||||
def test_missing_config_json(self, tmp_path):
|
||||
"""PiperTTS should require .onnx.json config file."""
|
||||
from src.jarvis.output.tts import PiperTTS
|
||||
from unittest.mock import patch
|
||||
|
||||
# Create a fake model file but no config
|
||||
model_file = tmp_path / "custom-voice.onnx"
|
||||
model_file.write_text("fake model")
|
||||
|
||||
tts = PiperTTS(enabled=True, model_path=str(model_file))
|
||||
|
||||
# Mock download to fail (since config doesn't exist)
|
||||
with patch("src.jarvis.output.tts._download_piper_voice", return_value=None):
|
||||
result = tts._ensure_initialized()
|
||||
assert result is False
|
||||
assert tts._init_error is not None
|
||||
|
||||
def test_user_path_expansion(self):
|
||||
"""PiperTTS should expand ~ in model path."""
|
||||
from src.jarvis.output.tts import PiperTTS
|
||||
from unittest.mock import patch
|
||||
import os
|
||||
|
||||
tts = PiperTTS(enabled=True, model_path="~/nonexistent/model.onnx")
|
||||
|
||||
# Mock download to fail
|
||||
with patch("src.jarvis.output.tts._download_piper_voice", return_value=None):
|
||||
tts._ensure_initialized()
|
||||
|
||||
# The error should reference the expanded path (with home directory)
|
||||
# not the literal ~
|
||||
if tts._init_error:
|
||||
# Either the path was expanded, or we got a different error
|
||||
# Both are acceptable as long as it didn't crash
|
||||
pass
|
||||
|
||||
def test_explicit_model_path_skips_default(self, tmp_path):
|
||||
"""When explicit model_path is given, don't use default."""
|
||||
from src.jarvis.output.tts import PiperTTS, _get_default_piper_model_path
|
||||
from unittest.mock import patch
|
||||
|
||||
custom_path = str(tmp_path / "custom-voice.onnx")
|
||||
tts = PiperTTS(enabled=True, model_path=custom_path)
|
||||
|
||||
# Mock download to return the custom path
|
||||
with patch("src.jarvis.output.tts._download_piper_voice", return_value=None):
|
||||
tts._ensure_initialized()
|
||||
|
||||
# Error should reference the custom path, not default
|
||||
if tts._init_error:
|
||||
default_path = _get_default_piper_model_path()
|
||||
# Should not be using the default path
|
||||
assert "custom-voice" in tts._init_error or "download" in tts._init_error.lower()
|
||||
|
||||
|
||||
class TestPiperTTSWithMocking:
|
||||
"""Tests for PiperTTS with mocked Piper library."""
|
||||
|
||||
def test_initialization_checks_both_files(self, tmp_path):
|
||||
"""PiperTTS should check both .onnx and .onnx.json files exist."""
|
||||
from src.jarvis.output.tts import PiperTTS
|
||||
from unittest.mock import patch
|
||||
|
||||
# Create model file but not config
|
||||
model_file = tmp_path / "test-voice.onnx"
|
||||
model_file.write_text("fake model")
|
||||
|
||||
tts = PiperTTS(enabled=True, model_path=str(model_file))
|
||||
|
||||
# Mock download to fail
|
||||
with patch("src.jarvis.output.tts._download_piper_voice", return_value=None):
|
||||
result = tts._ensure_initialized()
|
||||
|
||||
assert result is False
|
||||
assert tts._init_error is not None
|
||||
|
||||
@patch("src.jarvis.output.tts.os.path.exists", return_value=True)
|
||||
def test_piper_import_error_handling(self, mock_exists):
|
||||
"""PiperTTS should handle missing piper-tts library gracefully."""
|
||||
from src.jarvis.output.tts import PiperTTS
|
||||
|
||||
with patch.dict("sys.modules", {"piper": None, "piper.voice": None}):
|
||||
# Force reimport to trigger import error
|
||||
tts = PiperTTS(enabled=True, model_path="/fake/model.onnx")
|
||||
|
||||
# Clear any previous initialization state
|
||||
tts._initialized = False
|
||||
tts._voice = None
|
||||
tts._init_error = None
|
||||
|
||||
# Mock the import to raise ImportError
|
||||
with patch(
|
||||
"src.jarvis.output.tts.PiperTTS._ensure_initialized",
|
||||
wraps=tts._ensure_initialized,
|
||||
):
|
||||
# Manually trigger what would happen with import error
|
||||
tts._init_error = "piper-tts not installed"
|
||||
tts._initialized = True
|
||||
result = tts._ensure_initialized()
|
||||
|
||||
# Should have caught the error
|
||||
assert tts._init_error is not None
|
||||
|
||||
def test_speak_queues_text(self):
|
||||
"""PiperTTS.speak should queue text for processing."""
|
||||
from src.jarvis.output.tts import PiperTTS
|
||||
|
||||
tts = PiperTTS(enabled=True, model_path="/fake/model.onnx")
|
||||
|
||||
# Don't actually start the thread
|
||||
tts.speak("Hello world")
|
||||
|
||||
# Text should be in queue (may have been preprocessed)
|
||||
assert not tts._q.empty()
|
||||
|
||||
def test_speak_does_nothing_when_disabled(self):
|
||||
"""PiperTTS.speak should do nothing when disabled."""
|
||||
from src.jarvis.output.tts import PiperTTS
|
||||
|
||||
tts = PiperTTS(enabled=False)
|
||||
tts.speak("Hello world")
|
||||
|
||||
# Queue should be empty
|
||||
assert tts._q.empty()
|
||||
|
||||
def test_speak_does_nothing_for_empty_text(self):
|
||||
"""PiperTTS.speak should do nothing for empty text."""
|
||||
from src.jarvis.output.tts import PiperTTS
|
||||
|
||||
tts = PiperTTS(enabled=True, model_path="/fake/model.onnx")
|
||||
tts.speak("")
|
||||
tts.speak(" ")
|
||||
|
||||
# Queue should be empty
|
||||
assert tts._q.empty()
|
||||
|
||||
def test_interrupt_sets_flag(self):
|
||||
"""PiperTTS.interrupt should set the interrupt flag."""
|
||||
from src.jarvis.output.tts import PiperTTS
|
||||
|
||||
tts = PiperTTS(enabled=True)
|
||||
|
||||
assert not tts._should_interrupt.is_set()
|
||||
tts.interrupt()
|
||||
assert tts._should_interrupt.is_set()
|
||||
|
||||
def test_is_speaking_returns_event_state(self):
|
||||
"""PiperTTS.is_speaking should return the speaking event state."""
|
||||
from src.jarvis.output.tts import PiperTTS
|
||||
|
||||
tts = PiperTTS(enabled=True)
|
||||
|
||||
assert tts.is_speaking() is False
|
||||
|
||||
tts._is_speaking.set()
|
||||
assert tts.is_speaking() is True
|
||||
|
||||
tts._is_speaking.clear()
|
||||
assert tts.is_speaking() is False
|
||||
|
||||
def test_get_last_spoken_text_returns_stored_text(self):
|
||||
"""PiperTTS.get_last_spoken_text should return the last spoken text."""
|
||||
from src.jarvis.output.tts import PiperTTS
|
||||
|
||||
tts = PiperTTS(enabled=True)
|
||||
|
||||
assert tts.get_last_spoken_text() == ""
|
||||
|
||||
tts._last_spoken_text = "Hello world"
|
||||
assert tts.get_last_spoken_text() == "Hello world"
|
||||
|
||||
|
||||
class TestPiperTTSFactory:
|
||||
"""Tests for the create_tts_engine factory function."""
|
||||
|
||||
def test_creates_piper_engine(self):
|
||||
"""create_tts_engine should create PiperTTS for engine='piper'."""
|
||||
from src.jarvis.output.tts import create_tts_engine, PiperTTS
|
||||
|
||||
tts = create_tts_engine(engine="piper", enabled=False)
|
||||
assert isinstance(tts, PiperTTS)
|
||||
|
||||
def test_creates_piper_engine_case_insensitive(self):
|
||||
"""create_tts_engine should handle 'PIPER', 'Piper', etc."""
|
||||
from src.jarvis.output.tts import create_tts_engine, PiperTTS
|
||||
|
||||
tts1 = create_tts_engine(engine="PIPER", enabled=False)
|
||||
tts2 = create_tts_engine(engine="Piper", enabled=False)
|
||||
|
||||
assert isinstance(tts1, PiperTTS)
|
||||
assert isinstance(tts2, PiperTTS)
|
||||
|
||||
def test_passes_piper_parameters(self):
|
||||
"""create_tts_engine should pass all Piper parameters."""
|
||||
from src.jarvis.output.tts import create_tts_engine, PiperTTS
|
||||
|
||||
tts = create_tts_engine(
|
||||
engine="piper",
|
||||
enabled=True,
|
||||
voice="test",
|
||||
rate=200,
|
||||
piper_model_path="/path/to/model.onnx",
|
||||
piper_speaker=1,
|
||||
piper_length_scale=0.9,
|
||||
piper_noise_scale=0.5,
|
||||
piper_noise_w=0.6,
|
||||
piper_sentence_silence=0.25,
|
||||
)
|
||||
|
||||
assert isinstance(tts, PiperTTS)
|
||||
assert tts.model_path == "/path/to/model.onnx"
|
||||
assert tts.speaker == 1
|
||||
assert tts.length_scale == 0.9
|
||||
assert tts.noise_scale == 0.5
|
||||
assert tts.noise_w == 0.6
|
||||
assert tts.sentence_silence == 0.25
|
||||
|
||||
def test_default_engine_is_piper(self):
|
||||
"""create_tts_engine should default to Piper TTS."""
|
||||
from src.jarvis.output.tts import create_tts_engine, PiperTTS
|
||||
|
||||
tts = create_tts_engine(enabled=False)
|
||||
assert isinstance(tts, PiperTTS)
|
||||
|
||||
def test_unknown_engine_falls_back_to_piper(self):
|
||||
"""create_tts_engine with unknown engine should create PiperTTS."""
|
||||
from src.jarvis.output.tts import create_tts_engine, PiperTTS
|
||||
|
||||
tts = create_tts_engine(engine="unknown", enabled=False)
|
||||
assert isinstance(tts, PiperTTS)
|
||||
|
||||
def test_chatterbox_engine_still_works(self):
|
||||
"""create_tts_engine should still create ChatterboxTTS."""
|
||||
from src.jarvis.output.tts import create_tts_engine, ChatterboxTTS
|
||||
|
||||
tts = create_tts_engine(engine="chatterbox", enabled=False)
|
||||
assert isinstance(tts, ChatterboxTTS)
|
||||
|
||||
|
||||
class TestPiperTTSAutoDownload:
|
||||
"""Tests for Piper TTS auto-download functionality."""
|
||||
|
||||
def test_get_default_model_path(self):
|
||||
"""Default model path should be in ~/.local/share/jarvis/models/piper/."""
|
||||
from src.jarvis.output.tts import _get_default_piper_model_path, PIPER_DEFAULT_VOICE
|
||||
|
||||
path = _get_default_piper_model_path()
|
||||
|
||||
assert PIPER_DEFAULT_VOICE in path
|
||||
assert path.endswith(".onnx")
|
||||
assert "jarvis" in path
|
||||
assert "piper" in path
|
||||
|
||||
def test_get_piper_models_dir(self):
|
||||
"""Models directory should be created under jarvis data dir."""
|
||||
from src.jarvis.output.tts import _get_piper_models_dir
|
||||
|
||||
models_dir = _get_piper_models_dir()
|
||||
|
||||
assert models_dir.exists()
|
||||
assert "jarvis" in str(models_dir)
|
||||
assert "piper" in str(models_dir)
|
||||
|
||||
def test_piper_uses_default_when_no_path(self):
|
||||
"""PiperTTS should use default model path when none configured."""
|
||||
from src.jarvis.output.tts import PiperTTS, _get_default_piper_model_path
|
||||
|
||||
tts = PiperTTS(enabled=True, model_path=None)
|
||||
|
||||
# model_path starts as None
|
||||
assert tts.model_path is None
|
||||
|
||||
# But initialization should use the default
|
||||
# (we don't actually init here to avoid downloads in tests)
|
||||
|
||||
def test_default_voice_is_reasonable(self):
|
||||
"""Default voice should be a reasonable choice."""
|
||||
from src.jarvis.output.tts import PIPER_DEFAULT_VOICE
|
||||
|
||||
# Should be British English
|
||||
assert PIPER_DEFAULT_VOICE.startswith("en_GB")
|
||||
# Should include quality indicator
|
||||
assert "medium" in PIPER_DEFAULT_VOICE or "high" in PIPER_DEFAULT_VOICE
|
||||
|
||||
|
||||
class TestPiperTTSConfig:
|
||||
"""Tests for Piper TTS configuration in Settings."""
|
||||
|
||||
def test_config_has_piper_fields(self):
|
||||
"""Settings dataclass should have all Piper TTS fields."""
|
||||
from src.jarvis.config import Settings
|
||||
import inspect
|
||||
|
||||
# Get the field names from Settings
|
||||
signature = inspect.signature(Settings)
|
||||
param_names = set(signature.parameters.keys())
|
||||
|
||||
# Check all Piper fields exist
|
||||
assert "tts_piper_model_path" in param_names
|
||||
assert "tts_piper_speaker" in param_names
|
||||
assert "tts_piper_length_scale" in param_names
|
||||
assert "tts_piper_noise_scale" in param_names
|
||||
assert "tts_piper_noise_w" in param_names
|
||||
assert "tts_piper_sentence_silence" in param_names
|
||||
|
||||
def test_default_config_has_piper_values(self):
|
||||
"""get_default_config should include Piper TTS defaults."""
|
||||
from src.jarvis.config import get_default_config
|
||||
|
||||
defaults = get_default_config()
|
||||
|
||||
assert "tts_piper_model_path" in defaults
|
||||
assert defaults["tts_piper_model_path"] is None
|
||||
|
||||
assert "tts_piper_speaker" in defaults
|
||||
assert defaults["tts_piper_speaker"] is None
|
||||
|
||||
assert "tts_piper_length_scale" in defaults
|
||||
assert defaults["tts_piper_length_scale"] == 0.65 # ~30% faster speech
|
||||
|
||||
assert "tts_piper_noise_scale" in defaults
|
||||
assert defaults["tts_piper_noise_scale"] == 0.8 # More expressive
|
||||
|
||||
assert "tts_piper_noise_w" in defaults
|
||||
assert defaults["tts_piper_noise_w"] == 1.0 # More lively
|
||||
|
||||
assert "tts_piper_sentence_silence" in defaults
|
||||
assert defaults["tts_piper_sentence_silence"] == 0.2
|
||||
|
||||
def test_tts_engine_defaults_to_piper(self):
|
||||
"""tts_engine should default to 'piper'."""
|
||||
from src.jarvis.config import load_settings, get_default_config
|
||||
from unittest.mock import patch
|
||||
|
||||
# Check default config
|
||||
defaults = get_default_config()
|
||||
assert defaults["tts_engine"] == "piper"
|
||||
|
||||
# Mock empty config file - should use default
|
||||
with patch("src.jarvis.config._load_json", return_value={}):
|
||||
settings = load_settings()
|
||||
assert settings.tts_engine == "piper"
|
||||
|
||||
def test_tts_engine_migrates_system_to_piper(self):
|
||||
"""tts_engine 'system' should be auto-migrated to 'piper' for existing users."""
|
||||
from src.jarvis.config import load_settings
|
||||
from unittest.mock import patch
|
||||
|
||||
# Old config with system TTS (no _config_version = pre-migration)
|
||||
config_data = {"tts_engine": "system"}
|
||||
|
||||
with patch("src.jarvis.config._load_json", return_value=config_data):
|
||||
with patch("src.jarvis.config._save_json", return_value=True):
|
||||
settings = load_settings()
|
||||
# Should be migrated to piper
|
||||
assert settings.tts_engine == "piper"
|
||||
|
||||
def test_invalid_engine_falls_back_to_piper(self):
|
||||
"""Invalid tts_engine values should fall back to piper."""
|
||||
from src.jarvis.config import load_settings
|
||||
from unittest.mock import patch
|
||||
|
||||
# Config with invalid TTS engine
|
||||
config_data = {
|
||||
"tts_engine": "invalid_engine",
|
||||
"_config_version": 1
|
||||
}
|
||||
|
||||
with patch("src.jarvis.config._load_json", return_value=config_data):
|
||||
settings = load_settings()
|
||||
# Should fall back to piper
|
||||
assert settings.tts_engine == "piper"
|
||||
|
||||
def test_chatterbox_engine_preserved(self):
|
||||
"""tts_engine 'chatterbox' should be preserved."""
|
||||
from src.jarvis.config import load_settings
|
||||
from unittest.mock import patch
|
||||
|
||||
config_data = {
|
||||
"tts_engine": "chatterbox",
|
||||
"_config_version": 1
|
||||
}
|
||||
|
||||
with patch("src.jarvis.config._load_json", return_value=config_data):
|
||||
settings = load_settings()
|
||||
assert settings.tts_engine == "chatterbox"
|
||||
|
||||
|
||||
class TestPiperTTSThreadSafety:
|
||||
"""Tests for PiperTTS thread safety."""
|
||||
|
||||
def test_multiple_interrupts_safe(self):
|
||||
"""Multiple calls to interrupt should be safe."""
|
||||
from src.jarvis.output.tts import PiperTTS
|
||||
|
||||
tts = PiperTTS(enabled=True)
|
||||
|
||||
# Should not crash with multiple interrupts
|
||||
for _ in range(10):
|
||||
tts.interrupt()
|
||||
|
||||
def test_start_stop_cycle(self):
|
||||
"""Start and stop should be safe to call multiple times."""
|
||||
from src.jarvis.output.tts import PiperTTS
|
||||
|
||||
tts = PiperTTS(enabled=False) # Disabled to avoid actual model loading
|
||||
|
||||
# Multiple start/stop cycles should be safe
|
||||
for _ in range(3):
|
||||
tts.start()
|
||||
tts.stop()
|
||||
|
||||
def test_concurrent_speaks(self):
|
||||
"""Multiple threads calling speak should not crash."""
|
||||
from src.jarvis.output.tts import PiperTTS
|
||||
|
||||
tts = PiperTTS(enabled=True, model_path="/fake/model.onnx")
|
||||
|
||||
# Don't start the actual worker thread
|
||||
def speak_text():
|
||||
for _ in range(10):
|
||||
tts.speak("Hello world")
|
||||
|
||||
threads = [threading.Thread(target=speak_text) for _ in range(3)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# Should not crash, queue should have items
|
||||
# (actual number depends on timing)
|
||||
|
||||
|
||||
class TestPiperVoiceDownloadRetry:
|
||||
"""Tests for retry logic when HuggingFace returns 429 Too Many Requests."""
|
||||
|
||||
def test_429_retried_then_succeeds(self, tmp_path):
|
||||
"""Download retries on 429 and succeeds on subsequent attempt."""
|
||||
import requests
|
||||
from src.jarvis.output.tts import _download_piper_voice
|
||||
|
||||
call_count = {"onnx": 0, "json": 0}
|
||||
|
||||
def mock_get(url, **kwargs):
|
||||
resp = MagicMock()
|
||||
is_json = url.endswith(".json")
|
||||
key = "json" if is_json else "onnx"
|
||||
call_count[key] += 1
|
||||
|
||||
if call_count[key] == 1:
|
||||
# First call: 429
|
||||
http_err = requests.exceptions.HTTPError(
|
||||
response=MagicMock(status_code=429)
|
||||
)
|
||||
http_err.response = MagicMock(status_code=429)
|
||||
resp.raise_for_status.side_effect = http_err
|
||||
return resp
|
||||
|
||||
# Subsequent calls: success
|
||||
resp.raise_for_status.return_value = None
|
||||
resp.headers = {"content-length": "4"}
|
||||
resp.iter_content.return_value = [b"data"]
|
||||
return resp
|
||||
|
||||
with patch("requests.get", side_effect=mock_get):
|
||||
with patch("src.jarvis.output.tts._get_piper_models_dir", return_value=tmp_path):
|
||||
with patch("src.jarvis.output.tts.time.sleep") as mock_sleep:
|
||||
result = _download_piper_voice("en_GB-alan-medium")
|
||||
|
||||
assert result is not None
|
||||
assert (tmp_path / "en_GB-alan-medium.onnx").exists()
|
||||
# Verify exponential backoff: 2^1=2s for the onnx 429, 2^1=2s for the json 429
|
||||
sleep_values = [c.args[0] for c in mock_sleep.call_args_list]
|
||||
assert all(v == 2 for v in sleep_values)
|
||||
|
||||
def test_429_gives_up_after_max_retries(self, tmp_path):
|
||||
"""Download gives up after exhausting retries on persistent 429."""
|
||||
import requests
|
||||
from src.jarvis.output.tts import _download_piper_voice
|
||||
|
||||
def mock_get(url, **kwargs):
|
||||
resp = MagicMock()
|
||||
http_err = requests.exceptions.HTTPError(
|
||||
response=MagicMock(status_code=429)
|
||||
)
|
||||
http_err.response = MagicMock(status_code=429)
|
||||
resp.raise_for_status.side_effect = http_err
|
||||
return resp
|
||||
|
||||
with patch("requests.get", side_effect=mock_get):
|
||||
with patch("src.jarvis.output.tts._get_piper_models_dir", return_value=tmp_path):
|
||||
with patch("src.jarvis.output.tts.time.sleep") as mock_sleep:
|
||||
result = _download_piper_voice("en_GB-alan-medium")
|
||||
|
||||
assert result is None
|
||||
# Verify exponential backoff sequence: 2, 4, 8, 16
|
||||
sleep_values = [c.args[0] for c in mock_sleep.call_args_list]
|
||||
assert sleep_values == [2, 4, 8, 16]
|
||||
|
||||
def test_non_429_error_not_retried(self, tmp_path):
|
||||
"""Download does not retry on non-429 HTTP errors (e.g. 404)."""
|
||||
import requests
|
||||
from src.jarvis.output.tts import _download_piper_voice
|
||||
|
||||
get_call_count = 0
|
||||
|
||||
def mock_get(url, **kwargs):
|
||||
nonlocal get_call_count
|
||||
get_call_count += 1
|
||||
resp = MagicMock()
|
||||
http_err = requests.exceptions.HTTPError(
|
||||
response=MagicMock(status_code=404)
|
||||
)
|
||||
http_err.response = MagicMock(status_code=404)
|
||||
resp.raise_for_status.side_effect = http_err
|
||||
return resp
|
||||
|
||||
with patch("requests.get", side_effect=mock_get):
|
||||
with patch("src.jarvis.output.tts._get_piper_models_dir", return_value=tmp_path):
|
||||
result = _download_piper_voice("en_GB-alan-medium")
|
||||
|
||||
assert result is None
|
||||
# Should only call once for the onnx file (no retry)
|
||||
assert get_call_count == 1
|
||||
Reference in New Issue
Block a user