feat(phase-3): Chronos zero-shot 예측 + 피처 빌더
- backend/app/models/chronos.py: amazon/chronos-t5-small (env CHRONOS_MODEL
override 가능). lazy singleton, cuda + bf16 자동, q10/median/q90 + raw
samples 반환 (앙상블 가중평균용).
- backend/app/models/features.py: 종목별 학습/추론 피처 DataFrame.
OHLCV + TA(rsi/macd/atr/bb/sma/ema/vol_z) + 외인기관거래대금 + macro
(kospi/kosdaq/usdkrw/us10y + r1) + sentiment(v_sentiment_daily, 3d rolling).
학습용은 with_targets=True 로 y_close_h{1,3,5}, y_ret_h*, y_dir_h*
(±0.3% flat band) 추가.
- pyproject.toml: chronos-forecasting 1.4.1, accelerate 0.30.1, joblib 1.4.2.
이 단계까지는 코드만. 실제 모델 다운로드는 첫 ping/predict 호출 시점에.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
118
backend/app/models/chronos.py
Normal file
118
backend/app/models/chronos.py
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
"""Chronos zero-shot 시계열 예측 어댑터.
|
||||||
|
|
||||||
|
모델: amazon/chronos-t5-small (46M, 빠르고 RTX 3070 Ti 에 충분히 들어감).
|
||||||
|
환경변수 CHRONOS_MODEL 로 base/large 로 바꿀 수 있음.
|
||||||
|
|
||||||
|
입력: 종가 시계열 (list[float], 최소 32 step).
|
||||||
|
출력: horizon 일 quantile forecast (q10/median/q90).
|
||||||
|
|
||||||
|
lazy singleton 으로 첫 호출 시 모델 로드. 디바이스는 settings.model_device 따라.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
MODEL_NAME = os.environ.get("CHRONOS_MODEL", "amazon/chronos-t5-small")
|
||||||
|
|
||||||
|
_lock = threading.Lock()
|
||||||
|
_state: dict[str, object] = {"loaded": False, "pipe": None, "device": None}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChronosForecast:
|
||||||
|
horizon: int
|
||||||
|
median: list[float]
|
||||||
|
q10: list[float]
|
||||||
|
q90: list[float]
|
||||||
|
samples: list[list[float]] # raw samples for ensemble downstream
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_device() -> str:
|
||||||
|
import torch # lazy
|
||||||
|
|
||||||
|
pref = (settings.model_device or "auto").lower()
|
||||||
|
if pref == "cuda":
|
||||||
|
return "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
if pref == "cpu":
|
||||||
|
return "cpu"
|
||||||
|
return "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
|
||||||
|
def _load() -> None:
|
||||||
|
global _state
|
||||||
|
with _lock:
|
||||||
|
if _state["loaded"]:
|
||||||
|
return
|
||||||
|
import torch
|
||||||
|
from chronos import ChronosPipeline
|
||||||
|
|
||||||
|
token = settings.huggingface_token or None
|
||||||
|
if token:
|
||||||
|
os.environ.setdefault("HUGGINGFACE_HUB_TOKEN", token)
|
||||||
|
os.environ.setdefault("HF_TOKEN", token)
|
||||||
|
|
||||||
|
device = _resolve_device()
|
||||||
|
# bf16 은 RTX 30xx 이상에서 지원. cpu 에선 fp32.
|
||||||
|
dtype = torch.bfloat16 if device == "cuda" else torch.float32
|
||||||
|
logger.info("loading Chronos %s on %s (dtype=%s)", MODEL_NAME, device, dtype)
|
||||||
|
pipe = ChronosPipeline.from_pretrained(
|
||||||
|
MODEL_NAME,
|
||||||
|
device_map=device,
|
||||||
|
torch_dtype=dtype,
|
||||||
|
)
|
||||||
|
_state.update({"loaded": True, "pipe": pipe, "device": device})
|
||||||
|
|
||||||
|
|
||||||
|
def forecast(
|
||||||
|
series: list[float],
|
||||||
|
*,
|
||||||
|
horizon: int = 5,
|
||||||
|
num_samples: int = 30,
|
||||||
|
) -> ChronosForecast:
|
||||||
|
"""series 의 마지막 시점 이후 horizon 일 예측.
|
||||||
|
|
||||||
|
series 는 일봉 종가. 최소 32개 권장 (그보다 짧으면 Chronos 분위 안정성 떨어짐).
|
||||||
|
"""
|
||||||
|
if len(series) < 32:
|
||||||
|
raise ValueError(
|
||||||
|
f"series too short ({len(series)}) for Chronos forecast (need >=32)"
|
||||||
|
)
|
||||||
|
_load()
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
pipe = _state["pipe"]
|
||||||
|
context = torch.tensor([float(x) for x in series], dtype=torch.float32)
|
||||||
|
with torch.no_grad():
|
||||||
|
samples = pipe.predict(
|
||||||
|
context=context,
|
||||||
|
prediction_length=horizon,
|
||||||
|
num_samples=num_samples,
|
||||||
|
)
|
||||||
|
# samples: (1, num_samples, prediction_length)
|
||||||
|
arr = samples[0].cpu().float().numpy()
|
||||||
|
q10 = np.quantile(arr, 0.10, axis=0).tolist()
|
||||||
|
median = np.quantile(arr, 0.50, axis=0).tolist()
|
||||||
|
q90 = np.quantile(arr, 0.90, axis=0).tolist()
|
||||||
|
return ChronosForecast(
|
||||||
|
horizon=horizon,
|
||||||
|
median=[float(x) for x in median],
|
||||||
|
q10=[float(x) for x in q10],
|
||||||
|
q90=[float(x) for x in q90],
|
||||||
|
samples=[[float(x) for x in row] for row in arr.tolist()],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def ping() -> dict[str, object]:
|
||||||
|
try:
|
||||||
|
_load()
|
||||||
|
return {"status": "ok", "model": MODEL_NAME, "device": _state["device"]}
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
return {"status": "failed", "model": MODEL_NAME, "error": str(exc)}
|
||||||
223
backend/app/models/features.py
Normal file
223
backend/app/models/features.py
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
"""모델 학습/추론용 피처 빌더.
|
||||||
|
|
||||||
|
종목 1개 + 룩백 기간을 받아 (date 단위) DataFrame 반환:
|
||||||
|
- OHLCV
|
||||||
|
- returns r1
|
||||||
|
- TA: rsi14, macd, macd_signal, atr14, bb_pct, sma20, ema12, vol_z20
|
||||||
|
- trading_value: foreign_net, institution_net, individual_net (정규화 X, scale 그대로)
|
||||||
|
- macro 정렬: kospi, kosdaq, usdkrw, us10y, kospi_r1, usdkrw_r1
|
||||||
|
- sentiment (v_sentiment_daily): mean_score, weighted_score, n_articles,
|
||||||
|
pos_minus_neg = pos_ratio - neg_ratio. 3일 롤링 mean 도 추가.
|
||||||
|
|
||||||
|
학습 타깃 (build_features 에서만 생성):
|
||||||
|
- y_close_h{1,3,5}: close.shift(-H)
|
||||||
|
- y_ret_h{1,3,5}: y_close_h / close - 1
|
||||||
|
- y_dir_h{1,3,5}: sign(y_ret_h) (1=up, -1=down, 0=flat ±0.3% 이내)
|
||||||
|
|
||||||
|
inference 용 build_features 는 dropna 안 함. 학습용 build_training_frame 은 dropna.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import date, timedelta
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from app.db.connection import get_engine
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
FLAT_BAND = 0.003 # ±0.3% 이내는 flat
|
||||||
|
HORIZONS_DEFAULT = (1, 3, 5)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FeatureFrame:
|
||||||
|
code: str
|
||||||
|
df: pd.DataFrame
|
||||||
|
target_horizons: tuple[int, ...]
|
||||||
|
|
||||||
|
|
||||||
|
def _load_ohlcv(code: str, start: date, end: date) -> pd.DataFrame:
|
||||||
|
eng = get_engine()
|
||||||
|
sql = text(
|
||||||
|
"""
|
||||||
|
SELECT date, open, high, low, close, volume
|
||||||
|
FROM ohlcv_daily
|
||||||
|
WHERE code = :code AND date BETWEEN :s AND :e
|
||||||
|
ORDER BY date
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
with eng.connect() as conn:
|
||||||
|
rows = conn.execute(sql, {"code": code, "s": start, "e": end}).all()
|
||||||
|
if not rows:
|
||||||
|
return pd.DataFrame(columns=["date", "open", "high", "low", "close", "volume"])
|
||||||
|
df = pd.DataFrame(rows, columns=["date", "open", "high", "low", "close", "volume"])
|
||||||
|
df["date"] = pd.to_datetime(df["date"]).dt.date
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def _load_trading(code: str, start: date, end: date) -> pd.DataFrame:
|
||||||
|
eng = get_engine()
|
||||||
|
sql = text(
|
||||||
|
"""
|
||||||
|
SELECT date, foreign_net, institution_net, individual_net
|
||||||
|
FROM trading_value_daily
|
||||||
|
WHERE code = :code AND date BETWEEN :s AND :e
|
||||||
|
ORDER BY date
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
with eng.connect() as conn:
|
||||||
|
rows = conn.execute(sql, {"code": code, "s": start, "e": end}).all()
|
||||||
|
if not rows:
|
||||||
|
return pd.DataFrame(columns=["date", "foreign_net", "institution_net", "individual_net"])
|
||||||
|
df = pd.DataFrame(rows, columns=["date", "foreign_net", "institution_net", "individual_net"])
|
||||||
|
df["date"] = pd.to_datetime(df["date"]).dt.date
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def _load_macro(start: date, end: date) -> pd.DataFrame:
|
||||||
|
eng = get_engine()
|
||||||
|
sql = text(
|
||||||
|
"SELECT date, key, value FROM macro_daily "
|
||||||
|
"WHERE date BETWEEN :s AND :e ORDER BY date"
|
||||||
|
)
|
||||||
|
with eng.connect() as conn:
|
||||||
|
rows = conn.execute(sql, {"s": start, "e": end}).all()
|
||||||
|
if not rows:
|
||||||
|
return pd.DataFrame(columns=["date"])
|
||||||
|
df = pd.DataFrame(rows, columns=["date", "key", "value"])
|
||||||
|
pivot = df.pivot_table(index="date", columns="key", values="value", aggfunc="last").reset_index()
|
||||||
|
pivot["date"] = pd.to_datetime(pivot["date"]).dt.date
|
||||||
|
pivot.columns.name = None
|
||||||
|
return pivot
|
||||||
|
|
||||||
|
|
||||||
|
def _load_sentiment(code: str, start: date, end: date) -> pd.DataFrame:
|
||||||
|
eng = get_engine()
|
||||||
|
sql = text(
|
||||||
|
"""
|
||||||
|
SELECT date, n_articles, mean_score, pos_ratio, neg_ratio,
|
||||||
|
weighted_score
|
||||||
|
FROM v_sentiment_daily
|
||||||
|
WHERE code = :code AND date BETWEEN :s AND :e
|
||||||
|
ORDER BY date
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
with eng.connect() as conn:
|
||||||
|
rows = conn.execute(sql, {"code": code, "s": start, "e": end}).all()
|
||||||
|
cols = ["date", "n_articles", "mean_score", "pos_ratio", "neg_ratio", "weighted_score"]
|
||||||
|
if not rows:
|
||||||
|
return pd.DataFrame(columns=cols)
|
||||||
|
df = pd.DataFrame(rows, columns=cols)
|
||||||
|
df["date"] = pd.to_datetime(df["date"]).dt.date
|
||||||
|
df["pos_minus_neg"] = df["pos_ratio"].fillna(0) - df["neg_ratio"].fillna(0)
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def _add_ta(df: pd.DataFrame) -> pd.DataFrame:
|
||||||
|
"""ta 패키지로 기술 지표 추가."""
|
||||||
|
from ta.momentum import RSIIndicator
|
||||||
|
from ta.trend import EMAIndicator, MACD, SMAIndicator
|
||||||
|
from ta.volatility import AverageTrueRange, BollingerBands
|
||||||
|
|
||||||
|
close = df["close"].astype(float)
|
||||||
|
high = df["high"].astype(float)
|
||||||
|
low = df["low"].astype(float)
|
||||||
|
vol = df["volume"].astype(float)
|
||||||
|
|
||||||
|
df["r1"] = close.pct_change()
|
||||||
|
df["rsi14"] = RSIIndicator(close=close, window=14, fillna=False).rsi()
|
||||||
|
macd = MACD(close=close, window_slow=26, window_fast=12, window_sign=9, fillna=False)
|
||||||
|
df["macd"] = macd.macd()
|
||||||
|
df["macd_signal"] = macd.macd_signal()
|
||||||
|
df["atr14"] = AverageTrueRange(high=high, low=low, close=close, window=14, fillna=False).average_true_range()
|
||||||
|
bb = BollingerBands(close=close, window=20, window_dev=2, fillna=False)
|
||||||
|
df["bb_pct"] = bb.bollinger_pband()
|
||||||
|
df["sma20"] = SMAIndicator(close=close, window=20, fillna=False).sma_indicator()
|
||||||
|
df["ema12"] = EMAIndicator(close=close, window=12, fillna=False).ema_indicator()
|
||||||
|
vol_mean = vol.rolling(20).mean()
|
||||||
|
vol_std = vol.rolling(20).std().replace(0, np.nan)
|
||||||
|
df["vol_z20"] = (vol - vol_mean) / vol_std
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def _add_targets(df: pd.DataFrame, horizons: tuple[int, ...]) -> pd.DataFrame:
|
||||||
|
close = df["close"].astype(float)
|
||||||
|
for h in horizons:
|
||||||
|
df[f"y_close_h{h}"] = close.shift(-h)
|
||||||
|
df[f"y_ret_h{h}"] = df[f"y_close_h{h}"] / close - 1.0
|
||||||
|
df[f"y_dir_h{h}"] = np.where(
|
||||||
|
df[f"y_ret_h{h}"] > FLAT_BAND, 1,
|
||||||
|
np.where(df[f"y_ret_h{h}"] < -FLAT_BAND, -1, 0),
|
||||||
|
)
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def build_features(
|
||||||
|
code: str,
|
||||||
|
*,
|
||||||
|
lookback_days: int = 365 * 2,
|
||||||
|
end_date: date | None = None,
|
||||||
|
horizons: tuple[int, ...] = HORIZONS_DEFAULT,
|
||||||
|
with_targets: bool = False,
|
||||||
|
) -> FeatureFrame:
|
||||||
|
"""code 1개 종목의 피처 DataFrame 생성.
|
||||||
|
|
||||||
|
inference: with_targets=False 로 호출 → 최신 row 의 피처만 LGBM/Chronos 에 투입.
|
||||||
|
training : with_targets=True 로 호출 → tail H 행은 타깃 NaN → dropna 로 제거.
|
||||||
|
"""
|
||||||
|
end = end_date or date.today()
|
||||||
|
start = end - timedelta(days=lookback_days)
|
||||||
|
|
||||||
|
ohlcv = _load_ohlcv(code, start, end)
|
||||||
|
if ohlcv.empty:
|
||||||
|
return FeatureFrame(code=code, df=ohlcv, target_horizons=horizons)
|
||||||
|
|
||||||
|
df = ohlcv.copy().sort_values("date").reset_index(drop=True)
|
||||||
|
|
||||||
|
df = _add_ta(df)
|
||||||
|
|
||||||
|
trading = _load_trading(code, start, end)
|
||||||
|
if not trading.empty:
|
||||||
|
df = df.merge(trading, on="date", how="left")
|
||||||
|
else:
|
||||||
|
for col in ("foreign_net", "institution_net", "individual_net"):
|
||||||
|
df[col] = np.nan
|
||||||
|
|
||||||
|
macro = _load_macro(start, end)
|
||||||
|
if not macro.empty:
|
||||||
|
df = df.merge(macro, on="date", how="left")
|
||||||
|
for k in ("kospi", "kosdaq", "usdkrw", "us10y"):
|
||||||
|
if k in df.columns:
|
||||||
|
df[f"{k}_r1"] = df[k].pct_change()
|
||||||
|
|
||||||
|
sentiment = _load_sentiment(code, start, end)
|
||||||
|
if not sentiment.empty:
|
||||||
|
df = df.merge(sentiment, on="date", how="left")
|
||||||
|
# 3일 롤링 평균
|
||||||
|
for col in ("mean_score", "weighted_score", "pos_minus_neg", "n_articles"):
|
||||||
|
if col in df.columns:
|
||||||
|
df[f"{col}_3d"] = df[col].rolling(3, min_periods=1).mean()
|
||||||
|
else:
|
||||||
|
for col in ("n_articles", "mean_score", "pos_ratio", "neg_ratio",
|
||||||
|
"weighted_score", "pos_minus_neg"):
|
||||||
|
df[col] = np.nan
|
||||||
|
|
||||||
|
if with_targets:
|
||||||
|
df = _add_targets(df, horizons)
|
||||||
|
|
||||||
|
return FeatureFrame(code=code, df=df, target_horizons=horizons)
|
||||||
|
|
||||||
|
|
||||||
|
def feature_columns(df: pd.DataFrame) -> list[str]:
|
||||||
|
"""LGBM 학습/추론용 피처 컬럼 목록. date / OHLCV / y_* 제외."""
|
||||||
|
drop = {"date", "open", "high", "low", "close", "volume"}
|
||||||
|
cols = [
|
||||||
|
c for c in df.columns
|
||||||
|
if c not in drop and not c.startswith("y_")
|
||||||
|
]
|
||||||
|
return cols
|
||||||
@@ -31,9 +31,12 @@ dependencies = [
|
|||||||
"transformers==4.41.2",
|
"transformers==4.41.2",
|
||||||
"tokenizers==0.19.1",
|
"tokenizers==0.19.1",
|
||||||
"sentencepiece==0.2.0",
|
"sentencepiece==0.2.0",
|
||||||
|
"accelerate==0.30.1",
|
||||||
|
"chronos-forecasting==1.4.1",
|
||||||
"scikit-learn==1.5.0",
|
"scikit-learn==1.5.0",
|
||||||
"lightgbm==4.3.0",
|
"lightgbm==4.3.0",
|
||||||
"ta==0.11.0",
|
"ta==0.11.0",
|
||||||
|
"joblib==1.4.2",
|
||||||
|
|
||||||
# scheduler
|
# scheduler
|
||||||
"apscheduler==3.10.4",
|
"apscheduler==3.10.4",
|
||||||
|
|||||||
Reference in New Issue
Block a user