Install and wire CUDA runtime for Windows STT
This commit is contained in:
@@ -2,13 +2,66 @@ import base64
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import site
|
||||
import traceback
|
||||
from typing import Any
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
|
||||
def configure_windows_dll_search_paths() -> list[str]:
|
||||
if sys.platform != "win32":
|
||||
return []
|
||||
|
||||
candidates: list[Path] = []
|
||||
executable_dir = Path(sys.executable).resolve().parent
|
||||
venv_root = executable_dir.parent
|
||||
candidates.extend(
|
||||
[
|
||||
venv_root / "Lib" / "site-packages" / "nvidia" / "cublas" / "bin",
|
||||
venv_root / "Lib" / "site-packages" / "nvidia" / "cudnn" / "bin",
|
||||
]
|
||||
)
|
||||
|
||||
for package_path in site.getsitepackages():
|
||||
base = Path(package_path)
|
||||
candidates.extend(
|
||||
[
|
||||
base / "nvidia" / "cublas" / "bin",
|
||||
base / "nvidia" / "cudnn" / "bin",
|
||||
]
|
||||
)
|
||||
|
||||
added: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for candidate in candidates:
|
||||
normalized = str(candidate.resolve())
|
||||
if normalized in seen:
|
||||
continue
|
||||
seen.add(normalized)
|
||||
if not candidate.exists():
|
||||
continue
|
||||
|
||||
os.add_dll_directory(normalized)
|
||||
if normalized not in os.environ.get("PATH", ""):
|
||||
os.environ["PATH"] = normalized + os.pathsep + os.environ.get("PATH", "")
|
||||
added.append(normalized)
|
||||
|
||||
return added
|
||||
|
||||
|
||||
os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1")
|
||||
CONFIGURED_DLL_PATHS = configure_windows_dll_search_paths()
|
||||
if CONFIGURED_DLL_PATHS:
|
||||
print(
|
||||
f"configured CUDA DLL search paths: {', '.join(CONFIGURED_DLL_PATHS)}",
|
||||
file=sys.stderr,
|
||||
flush=True,
|
||||
)
|
||||
|
||||
|
||||
def resolve_model() -> WhisperModel:
|
||||
model_name = os.environ.get("WHISPER_MODEL", "large-v3-turbo")
|
||||
requested_device = os.environ.get("WHISPER_DEVICE", "auto")
|
||||
|
||||
Reference in New Issue
Block a user