"""Chronos zero-shot 시계열 예측 어댑터. 모델: amazon/chronos-t5-small (46M, 빠르고 RTX 3070 Ti 에 충분히 들어감). 환경변수 CHRONOS_MODEL 로 base/large 로 바꿀 수 있음. 입력: 종가 시계열 (list[float], 최소 32 step). 출력: horizon 일 quantile forecast (q10/median/q90). lazy singleton 으로 첫 호출 시 모델 로드. 디바이스는 settings.model_device 따라. """ from __future__ import annotations import logging import os import threading from dataclasses import dataclass from app.config import settings logger = logging.getLogger(__name__) MODEL_NAME = os.environ.get("CHRONOS_MODEL", "amazon/chronos-t5-small") _lock = threading.Lock() _state: dict[str, object] = {"loaded": False, "pipe": None, "device": None} @dataclass class ChronosForecast: horizon: int median: list[float] q10: list[float] q90: list[float] samples: list[list[float]] # raw samples for ensemble downstream def _resolve_device() -> str: import torch # lazy pref = (settings.model_device or "auto").lower() if pref == "cuda": return "cuda" if torch.cuda.is_available() else "cpu" if pref == "cpu": return "cpu" return "cuda" if torch.cuda.is_available() else "cpu" def _load() -> None: global _state with _lock: if _state["loaded"]: return import torch from chronos import ChronosPipeline token = settings.huggingface_token or None if token: os.environ.setdefault("HUGGINGFACE_HUB_TOKEN", token) os.environ.setdefault("HF_TOKEN", token) device = _resolve_device() # bf16 은 RTX 30xx 이상에서 지원. cpu 에선 fp32. dtype = torch.bfloat16 if device == "cuda" else torch.float32 logger.info("loading Chronos %s on %s (dtype=%s)", MODEL_NAME, device, dtype) pipe = ChronosPipeline.from_pretrained( MODEL_NAME, device_map=device, torch_dtype=dtype, ) _state.update({"loaded": True, "pipe": pipe, "device": device}) def forecast( series: list[float], *, horizon: int = 5, num_samples: int = 30, ) -> ChronosForecast: """series 의 마지막 시점 이후 horizon 일 예측. series 는 일봉 종가. 최소 32개 권장 (그보다 짧으면 Chronos 분위 안정성 떨어짐). """ if len(series) < 32: raise ValueError( f"series too short ({len(series)}) for Chronos forecast (need >=32)" ) _load() import numpy as np import torch pipe = _state["pipe"] context = torch.tensor([float(x) for x in series], dtype=torch.float32) with torch.no_grad(): samples = pipe.predict( context=context, prediction_length=horizon, num_samples=num_samples, ) # samples: (1, num_samples, prediction_length) arr = samples[0].cpu().float().numpy() q10 = np.quantile(arr, 0.10, axis=0).tolist() median = np.quantile(arr, 0.50, axis=0).tolist() q90 = np.quantile(arr, 0.90, axis=0).tolist() return ChronosForecast( horizon=horizon, median=[float(x) for x in median], q10=[float(x) for x in q10], q90=[float(x) for x in q90], samples=[[float(x) for x in row] for row in arr.tolist()], ) def ping() -> dict[str, object]: try: _load() return {"status": "ok", "model": MODEL_NAME, "device": _state["device"]} except Exception as exc: # noqa: BLE001 return {"status": "failed", "model": MODEL_NAME, "error": str(exc)}