Install Windows CUDA runtime for STT
This commit is contained in:
@@ -1,7 +1,11 @@
|
||||
import base64
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
import site
|
||||
import sys
|
||||
import sysconfig
|
||||
import tempfile
|
||||
import traceback
|
||||
import wave
|
||||
@@ -44,6 +48,71 @@ def resolve_device() -> str:
|
||||
return "cpu"
|
||||
|
||||
|
||||
def configure_windows_cuda_runtime() -> None:
|
||||
if os.name != "nt":
|
||||
return
|
||||
|
||||
candidate_dirs: list[str] = []
|
||||
|
||||
for key in ("CUDA_PATH", "CUDA_HOME"):
|
||||
value = os.environ.get(key)
|
||||
if value:
|
||||
candidate_dirs.append(os.path.join(value, "bin"))
|
||||
|
||||
for key, value in os.environ.items():
|
||||
if key.startswith("CUDA_PATH_V") and value:
|
||||
candidate_dirs.append(os.path.join(value, "bin"))
|
||||
|
||||
candidate_dirs.extend(
|
||||
sorted(glob.glob(r"C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v*\bin"), reverse=True)
|
||||
)
|
||||
|
||||
site_roots: list[str] = []
|
||||
try:
|
||||
site_roots.extend(site.getsitepackages())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
site_roots.append(site.getusersitepackages())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
for key in ("purelib", "platlib"):
|
||||
value = sysconfig.get_paths().get(key)
|
||||
if value:
|
||||
site_roots.append(value)
|
||||
|
||||
for root in site_roots:
|
||||
nvidia_root = Path(root) / "nvidia"
|
||||
if not nvidia_root.is_dir():
|
||||
continue
|
||||
|
||||
for pattern in ("**/cublas64_12.dll", "**/cudnn*.dll", "**/cudart64*.dll"):
|
||||
for dll_path in nvidia_root.glob(pattern):
|
||||
candidate_dirs.append(str(dll_path.parent))
|
||||
|
||||
unique_dirs: list[str] = []
|
||||
for candidate in candidate_dirs:
|
||||
normalized = os.path.normpath(candidate)
|
||||
if not os.path.isdir(normalized):
|
||||
continue
|
||||
if normalized in unique_dirs:
|
||||
continue
|
||||
unique_dirs.append(normalized)
|
||||
|
||||
for directory in unique_dirs:
|
||||
try:
|
||||
os.add_dll_directory(directory)
|
||||
except (AttributeError, FileNotFoundError, OSError):
|
||||
pass
|
||||
|
||||
if unique_dirs:
|
||||
existing_path = os.environ.get("PATH", "")
|
||||
os.environ["PATH"] = os.pathsep.join(unique_dirs + [existing_path])
|
||||
log(f"configured CUDA DLL search paths: {', '.join(unique_dirs)}")
|
||||
|
||||
|
||||
def resolve_compute_type(device: str) -> str:
|
||||
raw = os.environ.get("LOCAL_STT_COMPUTE_TYPE", "auto").strip().lower()
|
||||
if raw and raw != "auto":
|
||||
@@ -55,6 +124,7 @@ def resolve_compute_type(device: str) -> str:
|
||||
|
||||
class SttWorker:
|
||||
def __init__(self) -> None:
|
||||
configure_windows_cuda_runtime()
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
self.model_name = os.environ.get("LOCAL_STT_MODEL", "tiny").strip() or "tiny"
|
||||
|
||||
Reference in New Issue
Block a user