"""Chronos zero-shot 시계열 예측 어댑터. 모델: amazon/chronos-t5-small (46M, 빠르고 RTX 3070 Ti 에 충분히 들어감). 환경변수 CHRONOS_MODEL 로 base/large 로 바꿀 수 있음. 입력: 종가 시계열 (list[float], 최소 32 step). 출력: horizon 일 quantile forecast (q10/median/q90). lazy singleton 으로 첫 호출 시 모델 로드. 디바이스는 settings.model_device 따라. """ from __future__ import annotations import logging import os import threading from dataclasses import dataclass from app.config import settings logger = logging.getLogger(__name__) MODEL_NAME = os.environ.get("CHRONOS_MODEL", "amazon/chronos-t5-small") _lock = threading.Lock() _state: dict[str, object] = {"loaded": False, "pipe": None, "device": None} @dataclass class ChronosForecast: horizon: int median: list[float] q10: list[float] q90: list[float] samples: list[list[float]] # raw samples for ensemble downstream def _resolve_device() -> str: import torch # lazy pref = (settings.model_device or "auto").lower() if pref == "cuda": return "cuda" if torch.cuda.is_available() else "cpu" if pref == "cpu": return "cpu" return "cuda" if torch.cuda.is_available() else "cpu" def _load() -> None: global _state with _lock: if _state["loaded"]: return import torch from chronos import ChronosPipeline token = settings.huggingface_token or None if token: os.environ.setdefault("HUGGINGFACE_HUB_TOKEN", token) os.environ.setdefault("HF_TOKEN", token) device = _resolve_device() # dtype 선택: # - 이전엔 cuda 면 무조건 bf16 으로 갔는데, torch 2.3.1+cu121 사전빌드 wheel 이 # sm_86 (RTX 3070 Ti) 의 일부 T5 커널 binary 를 빠뜨려서 inference 첫 호출에 # "no kernel image is available for execution on the device" 발생. ping/load # 까지는 통과해서 진단이 까다로웠음 (실제 005930 케이스에서 관측). # - chronos-t5-small 은 46M params 라 fp32 로도 8GB VRAM 에 여유 충분, 속도 # 차이도 일봉 30일 예측에선 무시 가능. 호환성 우선해 default 를 fp32 로. # - 드라이버/torch 업그레이드 후 다시 bf16 시험하려면 .env 에 # CHRONOS_DTYPE=bf16 (또는 fp16) 두면 됨. dtype_pref = os.environ.get("CHRONOS_DTYPE", "fp32").lower() if device == "cuda" and dtype_pref == "bf16": dtype = torch.bfloat16 elif device == "cuda" and dtype_pref == "fp16": dtype = torch.float16 else: dtype = torch.float32 logger.info("loading Chronos %s on %s (dtype=%s)", MODEL_NAME, device, dtype) pipe = ChronosPipeline.from_pretrained( MODEL_NAME, device_map=device, torch_dtype=dtype, ) _state.update({"loaded": True, "pipe": pipe, "device": device}) def _reload_cpu() -> None: """현재 pipeline 을 폐기하고 CPU 로 강제 재로드. cuda 환경에서 'no kernel image is available for execution on the device' 같이 런타임에야 드러나는 GPU 비호환 에러가 났을 때 자동 폴백용. 한 번 폴백하면 다음 호출부터는 CPU 그대로 사용 (재시도 비용 회피).""" global _state import torch from chronos import ChronosPipeline with _lock: logger.warning("falling back to CPU for Chronos (GPU inference failed)") _state.update({"loaded": False, "pipe": None, "device": None}) pipe = ChronosPipeline.from_pretrained( MODEL_NAME, device_map="cpu", torch_dtype=torch.float32, ) _state.update({"loaded": True, "pipe": pipe, "device": "cpu"}) def forecast( series: list[float], *, horizon: int = 5, num_samples: int = 30, ) -> ChronosForecast: """series 의 마지막 시점 이후 horizon 일 예측. series 는 일봉 종가. 최소 32개 권장 (그보다 짧으면 Chronos 분위 안정성 떨어짐). """ if len(series) < 32: raise ValueError( f"series too short ({len(series)}) for Chronos forecast (need >=32)" ) _load() import numpy as np import torch def _do_predict(): pipe = _state["pipe"] context = torch.tensor([float(x) for x in series], dtype=torch.float32) with torch.no_grad(): return pipe.predict( context=context, prediction_length=horizon, num_samples=num_samples, ) try: samples = _do_predict() except RuntimeError as exc: # cuda 빌드/드라이버 미스매치는 inference 시점에야 드러나는 경우가 많음. # 'no kernel image is available' / 'CUDA error' 같은 신호 잡아서 CPU 로 폴백. msg = str(exc) if _state.get("device") == "cuda" and ( "no kernel image" in msg or "CUDA error" in msg or "CUBLAS" in msg ): _reload_cpu() samples = _do_predict() else: raise # samples: (1, num_samples, prediction_length) arr = samples[0].cpu().float().numpy() q10 = np.quantile(arr, 0.10, axis=0).tolist() median = np.quantile(arr, 0.50, axis=0).tolist() q90 = np.quantile(arr, 0.90, axis=0).tolist() return ChronosForecast( horizon=horizon, median=[float(x) for x in median], q10=[float(x) for x in q10], q90=[float(x) for x in q90], samples=[[float(x) for x in row] for row in arr.tolist()], ) def ping() -> dict[str, object]: try: _load() return {"status": "ok", "model": MODEL_NAME, "device": _state["device"]} except Exception as exc: # noqa: BLE001 return {"status": "failed", "model": MODEL_NAME, "error": str(exc)}