From b1ca6ab5d31b6775a1244946e957305276dabca5 Mon Sep 17 00:00:00 2001
From: tkrmagid
Date: Wed, 20 May 2026 15:59:14 +0900
Subject: [PATCH] =?UTF-8?q?feat(phase-3):=20Chronos=20zero-shot=20?=
 =?UTF-8?q?=EC=98=88=EC=B8=A1=20+=20=ED=94=BC=EC=B2=98=20=EB=B9=8C?=
 =?UTF-8?q?=EB=8D=94?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- backend/app/models/chronos.py: amazon/chronos-t5-small (env CHRONOS_MODEL
  override 가능). lazy singleton, cuda + bf16 자동, q10/median/q90 + raw samples
  반환 (앙상블 가중평균용).
- backend/app/models/features.py: 종목별 학습/추론 피처 DataFrame. OHLCV +
  TA(rsi/macd/atr/bb/sma/ema/vol_z) + 외인기관거래대금 + macro
  (kospi/kosdaq/usdkrw/us10y + r1) + sentiment(v_sentiment_daily, 3d rolling).
  학습용은 with_targets=True 로 y_close_h{1,3,5}, y_ret_h*, y_dir_h*
  (±0.3% flat band) 추가.
- pyproject.toml: chronos-forecasting 1.4.1, accelerate 0.30.1, joblib 1.4.2.

이 단계까지는 코드만. 실제 모델 다운로드는 첫 ping/predict 호출 시점에.

Co-Authored-By: Claude Opus 4.7
---
 backend/app/models/chronos.py  | 138 ++++++++++++++++++
 backend/app/models/features.py | 254 +++++++++++++++++++++++++++++++++
 backend/pyproject.toml         |   3 +
 3 files changed, 395 insertions(+)
 create mode 100644 backend/app/models/chronos.py
 create mode 100644 backend/app/models/features.py

diff --git a/backend/app/models/chronos.py b/backend/app/models/chronos.py
new file mode 100644
index 0000000..7f660f6
--- /dev/null
+++ b/backend/app/models/chronos.py
@@ -0,0 +1,138 @@
+"""Chronos zero-shot time-series forecasting adapter.
+
+Model: amazon/chronos-t5-small by default (46M parameters; small enough for
+an RTX 3070 Ti). Set the CHRONOS_MODEL environment variable to switch to the
+base/large checkpoints.
+
+Input: daily close series (list[float], at least 32 steps).
+Output: quantile forecast (q10/median/q90) for the next `horizon` days.
+
+The model is a lazy singleton loaded on first use. The device follows
+settings.model_device.
+"""
+from __future__ import annotations
+
+import logging
+import os
+import threading
+from dataclasses import dataclass
+
+from app.config import settings
+
+logger = logging.getLogger(__name__)
+
+MODEL_NAME = os.environ.get("CHRONOS_MODEL", "amazon/chronos-t5-small")
+
+_lock = threading.Lock()
+_state: dict[str, object] = {"loaded": False, "pipe": None, "device": None}
+
+
+@dataclass
+class ChronosForecast:
+    """Quantile summary and raw sample paths of one forecast call."""
+
+    horizon: int
+    median: list[float]
+    q10: list[float]
+    q90: list[float]
+    samples: list[list[float]]  # raw samples for ensemble downstream
+
+
+def _resolve_device() -> str:
+    """Return "cuda" or "cpu" according to settings.model_device.
+
+    An explicit "cpu" is honoured as-is. Any other value (including "cuda"
+    and "auto") selects CUDA when available and falls back to the CPU.
+    """
+    import torch  # lazy
+
+    pref = (settings.model_device or "auto").lower()
+    if pref == "cpu":
+        return "cpu"
+    return "cuda" if torch.cuda.is_available() else "cpu"
+
+
+def _load() -> None:
+    """Load the Chronos pipeline once; later calls return immediately."""
+    with _lock:
+        if _state["loaded"]:
+            return
+        import torch
+        from chronos import ChronosPipeline
+
+        token = settings.huggingface_token or None
+        if token:
+            # setdefault: never override a token already in the environment.
+            os.environ.setdefault("HUGGINGFACE_HUB_TOKEN", token)
+            os.environ.setdefault("HF_TOKEN", token)
+
+        device = _resolve_device()
+        # bf16 on GPU (RTX 30xx or newer), fp32 on CPU.
+        dtype = torch.bfloat16 if device == "cuda" else torch.float32
+        logger.info("loading Chronos %s on %s (dtype=%s)", MODEL_NAME, device, dtype)
+        pipe = ChronosPipeline.from_pretrained(
+            MODEL_NAME,
+            device_map=device,
+            torch_dtype=dtype,
+        )
+        _state.update({"loaded": True, "pipe": pipe, "device": device})
+
+
+def forecast(
+    series: list[float],
+    *,
+    horizon: int = 5,
+    num_samples: int = 30,
+) -> ChronosForecast:
+    """Forecast `horizon` steps past the last point of `series`.
+
+    `series` holds daily closes. `num_samples` sample paths are drawn and
+    summarised per step as q10/median/q90; the raw paths are returned too.
+
+    Raises:
+        ValueError: if `series` has fewer than 32 points, or `horizon` /
+            `num_samples` is smaller than 1.
+    """
+    # Validate before _load() so bad arguments never trigger a model load.
+    if len(series) < 32:
+        raise ValueError(
+            f"series too short ({len(series)}) for Chronos forecast (need >=32)"
+        )
+    if horizon < 1:
+        raise ValueError(f"horizon must be >= 1 (got {horizon})")
+    if num_samples < 1:
+        raise ValueError(f"num_samples must be >= 1 (got {num_samples})")
+    _load()
+    import numpy as np
+    import torch
+
+    pipe = _state["pipe"]
+    context = torch.tensor([float(x) for x in series], dtype=torch.float32)
+    with torch.no_grad():
+        samples = pipe.predict(
+            context=context,
+            prediction_length=horizon,
+            num_samples=num_samples,
+        )
+    # samples: (1, num_samples, prediction_length)
+    # Cast to float32 before .numpy() because the model may run in bf16.
+    arr = samples[0].cpu().float().numpy()
+    q10, median, q90 = np.quantile(arr, [0.10, 0.50, 0.90], axis=0).tolist()
+    return ChronosForecast(
+        horizon=horizon,
+        median=median,
+        q10=q10,
+        q90=q90,
+        samples=arr.tolist(),
+    )
+
+
+def ping() -> dict[str, object]:
+    """Load the model if needed and report the outcome instead of raising."""
+    try:
+        _load()
+    except Exception as exc:  # noqa: BLE001
+        # Health check: report the failure but keep the traceback in the log.
+        logger.exception("Chronos load failed")
+        return {"status": "failed", "model": MODEL_NAME, "error": str(exc)}
+    return {"status": "ok", "model": MODEL_NAME, "device": _state["device"]}
diff --git a/backend/app/models/features.py b/backend/app/models/features.py
new file mode 100644
index 0000000..e000b05
--- /dev/null
+++ b/backend/app/models/features.py
@@ -0,0 +1,254 @@
+"""Feature builder for model training and inference.
+
+Given one stock code and a look-back window, returns a date-ordered DataFrame
+built on the OHLCV rows:
+  - OHLCV
+  - returns: r1
+  - TA: rsi14, macd, macd_signal, atr14, bb_pct, sma20, ema12, vol_z20
+  - trading value: foreign_net, institution_net, individual_net (raw scale,
+    not normalised)
+  - macro joined by date: kospi, kosdaq, usdkrw, us10y, plus <key>_r1 for each
+    of those keys present in macro_daily
+  - sentiment (v_sentiment_daily): n_articles, mean_score, pos_ratio,
+    neg_ratio, weighted_score, pos_minus_neg = pos_ratio - neg_ratio, plus
+    3-row rolling means (<col>_3d)
+
+Training targets (only with build_features(with_targets=True)):
+  - y_close_h{1,3,5}: close.shift(-H)
+  - y_ret_h{1,3,5}: y_close_h / close - 1
+  - y_dir_h{1,3,5}: 1 = up, -1 = down, 0 = flat (within +/-0.3%), NaN when the
+    future close is unknown
+
+build_features never drops rows; training code is expected to drop the rows
+whose targets are NaN.
+"""
+from __future__ import annotations
+
+import logging
+from dataclasses import dataclass
+from datetime import date, timedelta
+
+import numpy as np
+import pandas as pd
+from sqlalchemy import text
+
+from app.db.connection import get_engine
+
+logger = logging.getLogger(__name__)
+
+FLAT_BAND = 0.003  # returns within +/-0.3% count as flat
+HORIZONS_DEFAULT = (1, 3, 5)
+
+
+@dataclass
+class FeatureFrame:
+    """Feature table of one stock plus the target horizons requested for it."""
+
+    code: str
+    df: pd.DataFrame
+    target_horizons: tuple[int, ...]
+
+
+def _load_ohlcv(code: str, start: date, end: date) -> pd.DataFrame:
+    """Read daily OHLCV rows of `code` in [start, end], ordered by date."""
+    eng = get_engine()
+    sql = text(
+        """
+        SELECT date, open, high, low, close, volume
+        FROM ohlcv_daily
+        WHERE code = :code AND date BETWEEN :s AND :e
+        ORDER BY date
+        """
+    )
+    with eng.connect() as conn:
+        rows = conn.execute(sql, {"code": code, "s": start, "e": end}).all()
+    cols = ["date", "open", "high", "low", "close", "volume"]
+    if not rows:
+        return pd.DataFrame(columns=cols)
+    df = pd.DataFrame(rows, columns=cols)
+    df["date"] = pd.to_datetime(df["date"]).dt.date
+    return df
+
+
+def _load_trading(code: str, start: date, end: date) -> pd.DataFrame:
+    """Read daily net trading values of `code` in [start, end]."""
+    eng = get_engine()
+    sql = text(
+        """
+        SELECT date, foreign_net, institution_net, individual_net
+        FROM trading_value_daily
+        WHERE code = :code AND date BETWEEN :s AND :e
+        ORDER BY date
+        """
+    )
+    with eng.connect() as conn:
+        rows = conn.execute(sql, {"code": code, "s": start, "e": end}).all()
+    cols = ["date", "foreign_net", "institution_net", "individual_net"]
+    if not rows:
+        return pd.DataFrame(columns=cols)
+    df = pd.DataFrame(rows, columns=cols)
+    df["date"] = pd.to_datetime(df["date"]).dt.date
+    return df
+
+
+def _load_macro(start: date, end: date) -> pd.DataFrame:
+    """Read macro_daily in [start, end], pivoted to one column per key."""
+    eng = get_engine()
+    sql = text(
+        "SELECT date, key, value FROM macro_daily "
+        "WHERE date BETWEEN :s AND :e ORDER BY date"
+    )
+    with eng.connect() as conn:
+        rows = conn.execute(sql, {"s": start, "e": end}).all()
+    if not rows:
+        return pd.DataFrame(columns=["date"])
+    df = pd.DataFrame(rows, columns=["date", "key", "value"])
+    # aggfunc="last": keep the last value if a (date, key) pair repeats.
+    pivot = df.pivot_table(index="date", columns="key", values="value", aggfunc="last").reset_index()
+    pivot["date"] = pd.to_datetime(pivot["date"]).dt.date
+    pivot.columns.name = None
+    return pivot
+
+
+def _load_sentiment(code: str, start: date, end: date) -> pd.DataFrame:
+    """Read v_sentiment_daily rows of `code` and add pos_minus_neg."""
+    eng = get_engine()
+    sql = text(
+        """
+        SELECT date, n_articles, mean_score, pos_ratio, neg_ratio,
+               weighted_score
+        FROM v_sentiment_daily
+        WHERE code = :code AND date BETWEEN :s AND :e
+        ORDER BY date
+        """
+    )
+    with eng.connect() as conn:
+        rows = conn.execute(sql, {"code": code, "s": start, "e": end}).all()
+    cols = ["date", "n_articles", "mean_score", "pos_ratio", "neg_ratio", "weighted_score"]
+    if not rows:
+        return pd.DataFrame(columns=cols)
+    df = pd.DataFrame(rows, columns=cols)
+    df["date"] = pd.to_datetime(df["date"]).dt.date
+    # A missing ratio counts as 0 so the difference is always defined.
+    df["pos_minus_neg"] = df["pos_ratio"].fillna(0) - df["neg_ratio"].fillna(0)
+    return df
+
+
+def _add_ta(df: pd.DataFrame) -> pd.DataFrame:
+    """Add r1 and the technical indicators (via the `ta` package) in place."""
+    from ta.momentum import RSIIndicator
+    from ta.trend import EMAIndicator, MACD, SMAIndicator
+    from ta.volatility import AverageTrueRange, BollingerBands
+
+    close = df["close"].astype(float)
+    high = df["high"].astype(float)
+    low = df["low"].astype(float)
+    vol = df["volume"].astype(float)
+
+    df["r1"] = close.pct_change()
+    df["rsi14"] = RSIIndicator(close=close, window=14, fillna=False).rsi()
+    macd = MACD(close=close, window_slow=26, window_fast=12, window_sign=9, fillna=False)
+    df["macd"] = macd.macd()
+    df["macd_signal"] = macd.macd_signal()
+    if len(df) >= 14:
+        df["atr14"] = AverageTrueRange(high=high, low=low, close=close, window=14, fillna=False).average_true_range()
+    else:
+        # NOTE(review): ta's AverageTrueRange appears to raise IndexError when
+        # given fewer rows than its window (e.g. a newly listed stock).
+        df["atr14"] = np.nan
+    bb = BollingerBands(close=close, window=20, window_dev=2, fillna=False)
+    df["bb_pct"] = bb.bollinger_pband()
+    df["sma20"] = SMAIndicator(close=close, window=20, fillna=False).sma_indicator()
+    df["ema12"] = EMAIndicator(close=close, window=12, fillna=False).ema_indicator()
+    vol_mean = vol.rolling(20).mean()
+    # Zero std (constant volume) becomes NaN so the z-score is NaN, not inf.
+    vol_std = vol.rolling(20).std().replace(0, np.nan)
+    df["vol_z20"] = (vol - vol_mean) / vol_std
+    return df
+
+
+def _add_targets(df: pd.DataFrame, horizons: tuple[int, ...]) -> pd.DataFrame:
+    """Add y_close_h{H}, y_ret_h{H} and y_dir_h{H} for every H in `horizons`.
+
+    The last H rows have no future close, so all three targets are NaN there.
+    y_dir is stored as float (1.0 / -1.0 / 0.0) so that it can hold NaN.
+    """
+    close = df["close"].astype(float)
+    for h in horizons:
+        future = close.shift(-h)
+        ret = future / close - 1.0
+        # Sign of the return outside the flat band, 0 inside it.
+        direction = np.sign(ret).where(ret.abs() > FLAT_BAND, 0.0)
+        df[f"y_close_h{h}"] = future
+        df[f"y_ret_h{h}"] = ret
+        # NaN > x is False, so without this mask unknown futures would be
+        # labelled "flat" and survive a dropna on the direction column.
+        df[f"y_dir_h{h}"] = direction.where(ret.notna())
+    return df
+
+
+def build_features(
+    code: str,
+    *,
+    lookback_days: int = 365 * 2,
+    end_date: date | None = None,
+    horizons: tuple[int, ...] = HORIZONS_DEFAULT,
+    with_targets: bool = False,
+) -> FeatureFrame:
+    """Build the feature DataFrame of one stock.
+
+    inference: with_targets=False; the latest row feeds LGBM/Chronos.
+    training : with_targets=True; the last H rows carry NaN targets and are
+               meant to be dropped by the caller.
+
+    Trading and sentiment columns are always present (NaN when the source has
+    no rows); macro columns appear only for keys found in macro_daily. When
+    no OHLCV rows exist the returned frame is empty.
+    """
+    end = end_date or date.today()
+    start = end - timedelta(days=lookback_days)
+
+    ohlcv = _load_ohlcv(code, start, end)
+    if ohlcv.empty:
+        return FeatureFrame(code=code, df=ohlcv, target_horizons=horizons)
+
+    df = ohlcv.copy().sort_values("date").reset_index(drop=True)
+
+    df = _add_ta(df)
+
+    trading = _load_trading(code, start, end)
+    if not trading.empty:
+        df = df.merge(trading, on="date", how="left")
+    else:
+        for col in ("foreign_net", "institution_net", "individual_net"):
+            df[col] = np.nan
+
+    macro = _load_macro(start, end)
+    if not macro.empty:
+        df = df.merge(macro, on="date", how="left")
+        for k in ("kospi", "kosdaq", "usdkrw", "us10y"):
+            if k in df.columns:
+                df[f"{k}_r1"] = df[k].pct_change()
+
+    sentiment = _load_sentiment(code, start, end)
+    if not sentiment.empty:
+        df = df.merge(sentiment, on="date", how="left")
+    else:
+        for col in ("n_articles", "mean_score", "pos_ratio", "neg_ratio",
+                    "weighted_score", "pos_minus_neg"):
+            df[col] = np.nan
+    # Rolling mean over 3 rows, built on both paths so the feature columns do
+    # not depend on whether the stock has any sentiment rows.
+    for col in ("mean_score", "weighted_score", "pos_minus_neg", "n_articles"):
+        df[f"{col}_3d"] = df[col].rolling(3, min_periods=1).mean()
+
+    if with_targets:
+        df = _add_targets(df, horizons)
+
+    return FeatureFrame(code=code, df=df, target_horizons=horizons)
+
+
+def feature_columns(df: pd.DataFrame) -> list[str]:
+    """Return the LGBM feature columns: all but date, OHLCV and y_* targets."""
+    drop = {"date", "open", "high", "low", "close", "volume"}
+    return [c for c in df.columns if c not in drop and not c.startswith("y_")]
diff --git a/backend/pyproject.toml b/backend/pyproject.toml
index e4c7a35..8a97bca 100644
--- a/backend/pyproject.toml
+++ b/backend/pyproject.toml
@@ -31,9 +31,12 @@ dependencies = [
     "transformers==4.41.2",
     "tokenizers==0.19.1",
     "sentencepiece==0.2.0",
+    "accelerate==0.30.1",
+    "chronos-forecasting==1.4.1",
     "scikit-learn==1.5.0",
     "lightgbm==4.3.0",
     "ta==0.11.0",
+    "joblib==1.4.2",
 
     # scheduler
     "apscheduler==3.10.4",