feat(phase-4): LGBM 모델 + 앙상블 + 매칭/재학습 잡
- backend/app/models/lgbm.py: 종목 × horizon 별 LightGBM 회귀(y_ret_h)
+ 다중분류(y_dir_h, 3-class). joblib 으로 backend/data/models/{code}_h{H}_*.pkl
저장. early_stopping(30). predict_one() 으로 최신 영업일 피처에 추론.
- backend/app/models/weights.py: ensemble_weights 테이블 IO,
default w_chronos=0.6 / w_lgbm=0.4 (DB 행 없으면 fallback).
- backend/app/models/ensemble.py: Chronos sample 분포 + LGBM regression+cls
결합. point/q10/q90 + prob_up/flat/down + direction 라벨. 한쪽 모델
실패 시 다른 쪽 단독 fallback (cold start: chronos 단독).
- backend/app/pipelines/predict_one.py: predict_and_store(). 결과를
predictions 테이블에 UPSERT, user_triggered 누적 OR. base_date = 마지막
ohlcv 거래일, target_date = base_date + H 영업일(주말 스킵, 공휴일은
매칭잡에서 자연 보정).
- backend/app/pipelines/match_outcomes.py: target_date == d 인
user_triggered=TRUE 예측을 d 의 실제 종가와 매칭 → prediction_outcomes
적재. direction_hit(±0.3% flat band) + abs_error. 실제 종가 없으면
자연 skip.
- backend/app/pipelines/retrain_weekly.py: 시드 10종목 × H 재학습 +
최근 30일 model_performance 적재.
- backend/app/db/migrations/003_ensemble_weights.sql: (code, horizon) →
(w_chronos, w_lgbm, hit_rate_*, sample_count).
- backend/app/pipelines/scheduler.py:
daily_batch : 평일 16:00 KST
match_outcomes : 평일 16:30 KST ← 사용자가 확정한 매칭 시점
retrain_weekly : 일요일 02:00 KST
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
19
backend/app/db/migrations/003_ensemble_weights.sql
Normal file
19
backend/app/db/migrations/003_ensemble_weights.sql
Normal file
@@ -0,0 +1,19 @@
|
||||
-- Phase 4: 앙상블 가중치 저장.
|
||||
-- (code, horizon) 별로 Chronos vs LGBM 가중치. 일요일 02:00 재학습 잡에서 갱신.
|
||||
|
||||
\set ON_ERROR_STOP on
|
||||
|
||||
CREATE TABLE IF NOT EXISTS ensemble_weights (
|
||||
code TEXT NOT NULL REFERENCES symbols(code),
|
||||
horizon INT NOT NULL,
|
||||
w_chronos REAL NOT NULL DEFAULT 0.6,
|
||||
w_lgbm REAL NOT NULL DEFAULT 0.4,
|
||||
hit_rate_chronos REAL,
|
||||
hit_rate_lgbm REAL,
|
||||
sample_count INT,
|
||||
updated_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
PRIMARY KEY (code, horizon)
|
||||
);
|
||||
|
||||
COMMENT ON TABLE ensemble_weights IS
|
||||
'Phase 4: (code, horizon) 별 Chronos/LGBM 가중치. 최근 30일 prediction_outcomes hit_rate 기반 매주 갱신.';
|
||||
174
backend/app/models/ensemble.py
Normal file
174
backend/app/models/ensemble.py
Normal file
@@ -0,0 +1,174 @@
|
||||
"""Chronos + LGBM 앙상블 추론.
|
||||
|
||||
final_price[h] = w_c * chronos.median[h-1] + w_l * lgbm.predicted_close
|
||||
final_q10[h] = w_c * chronos.q10[h-1] + w_l * lgbm.predicted_close * 0.97
|
||||
final_q90[h] = w_c * chronos.q90[h-1] + w_l * lgbm.predicted_close * 1.03
|
||||
|
||||
LGBM 은 단일 horizon 의 다음 종가(point) 만 주므로, 그 자체로는 신뢰구간이 없음.
|
||||
근사로 ±3% band 를 LGBM 의 q10/q90 자리에 사용. Chronos 의 sample 분포가
|
||||
주된 신뢰구간 정보 (Chronos 우세하면 ci 가 좁아짐).
|
||||
|
||||
direction 확률:
|
||||
- LGBM 분류기에서 prob_up/flat/down (3-class) 그대로
|
||||
- Chronos 는 next-day return 부호 비율: samples.shift1 / base_close - 1 의 부호
|
||||
- 둘을 같은 가중치로 평균
|
||||
|
||||
LGBM 모델이 없으면 Chronos 단독으로 진행 (cold start).
|
||||
Chronos 도 실패하면 LGBM 단독으로 진행.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
|
||||
from app.models.chronos import ChronosForecast
|
||||
from app.models.chronos import forecast as chronos_forecast
|
||||
from app.models.lgbm import LgbmForecast
|
||||
from app.models.lgbm import predict_one as lgbm_predict
|
||||
from app.models.weights import load_weights
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnsembleStep:
|
||||
horizon: int # 1..H 거래일 후
|
||||
target_idx: int # chronos median 의 0-based 인덱스 (horizon-1)
|
||||
point_close: float
|
||||
ci_low: float
|
||||
ci_high: float
|
||||
prob_up: float
|
||||
prob_flat: float
|
||||
prob_down: float
|
||||
direction: str # 'up' / 'flat' / 'down'
|
||||
expected_return: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnsemblePrediction:
|
||||
code: str
|
||||
base_close: float
|
||||
horizons: list[int]
|
||||
steps: list[EnsembleStep]
|
||||
sources_used: list[str]
|
||||
|
||||
|
||||
def _chronos_direction(samples: list[list[float]], base_close: float, horizon: int) -> tuple[float, float, float]:
|
||||
"""Chronos sample 분포에서 (prob_up, prob_flat, prob_down). ±0.3% flat band."""
|
||||
if not samples:
|
||||
return 0.33, 0.34, 0.33
|
||||
arr = np.array(samples)[:, horizon - 1] # 해당 step 의 sample 값
|
||||
ret = arr / base_close - 1.0
|
||||
p_up = float((ret > 0.003).mean())
|
||||
p_dn = float((ret < -0.003).mean())
|
||||
p_fl = 1.0 - p_up - p_dn
|
||||
return p_up, p_fl, p_dn
|
||||
|
||||
|
||||
def predict(code: str, *, horizons: tuple[int, ...] = (1, 3, 5)) -> EnsemblePrediction:
|
||||
"""한 종목에 대해 horizons 별 앙상블 예측. on-demand 추론용."""
|
||||
max_h = max(horizons)
|
||||
|
||||
# Chronos: 종가 시계열 가져와서 max_h 까지 예측.
|
||||
from app.models.features import build_features # local import
|
||||
|
||||
ff = build_features(code, lookback_days=400, horizons=horizons, with_targets=False)
|
||||
df = ff.df
|
||||
if df.empty:
|
||||
raise RuntimeError(f"no OHLCV data for {code}")
|
||||
closes = df["close"].astype(float).tolist()
|
||||
base_close = float(closes[-1])
|
||||
|
||||
sources_used: list[str] = []
|
||||
cf: ChronosForecast | None = None
|
||||
try:
|
||||
cf = chronos_forecast(closes, horizon=max_h, num_samples=30)
|
||||
sources_used.append("chronos")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("chronos forecast failed for %s: %s", code, exc)
|
||||
|
||||
steps: list[EnsembleStep] = []
|
||||
for h in horizons:
|
||||
lf: LgbmForecast | None = None
|
||||
try:
|
||||
lf = lgbm_predict(code, h)
|
||||
if lf is not None:
|
||||
sources_used.append(f"lgbm_h{h}")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("lgbm predict failed for %s h=%d: %s", code, h, exc)
|
||||
|
||||
# 가중치 (DB 없으면 default 0.6/0.4).
|
||||
w = load_weights(code, h)
|
||||
wc, wl = w.w_chronos, w.w_lgbm
|
||||
# 한쪽이 없으면 다른 쪽 전부.
|
||||
if cf is None and lf is None:
|
||||
raise RuntimeError(f"both chronos & lgbm failed for {code} h={h}")
|
||||
if cf is None:
|
||||
wc, wl = 0.0, 1.0
|
||||
if lf is None:
|
||||
wc, wl = 1.0, 0.0
|
||||
|
||||
if cf is not None:
|
||||
c_med = cf.median[h - 1]
|
||||
c_q10 = cf.q10[h - 1]
|
||||
c_q90 = cf.q90[h - 1]
|
||||
else:
|
||||
c_med = c_q10 = c_q90 = base_close # not used (wc=0)
|
||||
|
||||
if lf is not None:
|
||||
l_close = lf.predicted_close
|
||||
l_lo = l_close * 0.97
|
||||
l_hi = l_close * 1.03
|
||||
l_pu, l_pf, l_pd = lf.prob_up, lf.prob_flat, lf.prob_down
|
||||
else:
|
||||
l_close = l_lo = l_hi = base_close
|
||||
l_pu = l_pf = l_pd = 0.0
|
||||
|
||||
point = wc * c_med + wl * l_close
|
||||
lo = wc * c_q10 + wl * l_lo
|
||||
hi = wc * c_q90 + wl * l_hi
|
||||
|
||||
if cf is not None:
|
||||
cp_up, cp_fl, cp_dn = _chronos_direction(cf.samples, base_close, h)
|
||||
else:
|
||||
cp_up = cp_fl = cp_dn = 0.0
|
||||
|
||||
# direction prob: source 마다 weights 동일하게 가중평균
|
||||
if lf is not None and cf is not None:
|
||||
p_up = 0.5 * cp_up + 0.5 * l_pu
|
||||
p_fl = 0.5 * cp_fl + 0.5 * l_pf
|
||||
p_dn = 0.5 * cp_dn + 0.5 * l_pd
|
||||
elif cf is not None:
|
||||
p_up, p_fl, p_dn = cp_up, cp_fl, cp_dn
|
||||
else:
|
||||
p_up, p_fl, p_dn = l_pu, l_pf, l_pd
|
||||
|
||||
# 정규화 (혹시 합이 0 가 아닐 때)
|
||||
s = max(p_up + p_fl + p_dn, 1e-9)
|
||||
p_up, p_fl, p_dn = p_up / s, p_fl / s, p_dn / s
|
||||
dir_lbl = "up" if p_up >= max(p_fl, p_dn) else ("down" if p_dn >= p_fl else "flat")
|
||||
|
||||
steps.append(
|
||||
EnsembleStep(
|
||||
horizon=h,
|
||||
target_idx=h - 1,
|
||||
point_close=float(point),
|
||||
ci_low=float(lo),
|
||||
ci_high=float(hi),
|
||||
prob_up=float(p_up),
|
||||
prob_flat=float(p_fl),
|
||||
prob_down=float(p_dn),
|
||||
direction=dir_lbl,
|
||||
expected_return=float(point / base_close - 1.0),
|
||||
)
|
||||
)
|
||||
|
||||
return EnsemblePrediction(
|
||||
code=code,
|
||||
base_close=base_close,
|
||||
horizons=list(horizons),
|
||||
steps=steps,
|
||||
sources_used=sources_used,
|
||||
)
|
||||
180
backend/app/models/lgbm.py
Normal file
180
backend/app/models/lgbm.py
Normal file
@@ -0,0 +1,180 @@
|
||||
"""LightGBM 회귀 + 분류 모델. 종목 × horizon 별 별도 저장.
|
||||
|
||||
- 회귀: target = y_ret_h{H}. 예측 후 base_close*(1+pred) 로 가격 환산.
|
||||
- 분류: target = y_dir_h{H} ∈ {-1, 0, +1}. 3-class softmax 로 prob_up/flat/down.
|
||||
|
||||
저장 경로: backend/data/models/{code}_h{H}_reg.pkl, _cls.pkl (joblib).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import joblib
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from app.models.features import build_features, feature_columns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MODEL_DIR = Path(os.environ.get("LGBM_MODEL_DIR", "/app/data/models"))
|
||||
|
||||
|
||||
@dataclass
|
||||
class LgbmForecast:
|
||||
horizon: int
|
||||
base_close: float
|
||||
predicted_close: float
|
||||
predicted_return: float
|
||||
prob_up: float
|
||||
prob_flat: float
|
||||
prob_down: float
|
||||
|
||||
|
||||
def _model_paths(code: str, horizon: int) -> tuple[Path, Path]:
|
||||
MODEL_DIR.mkdir(parents=True, exist_ok=True)
|
||||
return (
|
||||
MODEL_DIR / f"{code}_h{horizon}_reg.pkl",
|
||||
MODEL_DIR / f"{code}_h{horizon}_cls.pkl",
|
||||
)
|
||||
|
||||
|
||||
def _prepare_xy(code: str, horizon: int, lookback_days: int) -> tuple[pd.DataFrame, pd.Series, pd.Series, list[str]]:
|
||||
ff = build_features(
|
||||
code,
|
||||
lookback_days=lookback_days,
|
||||
horizons=(horizon,),
|
||||
with_targets=True,
|
||||
)
|
||||
df = ff.df
|
||||
if df.empty:
|
||||
return df, pd.Series(dtype=float), pd.Series(dtype=int), []
|
||||
y_ret_col = f"y_ret_h{horizon}"
|
||||
y_dir_col = f"y_dir_h{horizon}"
|
||||
# 타깃 NaN (마지막 H 행) 제거.
|
||||
df = df.dropna(subset=[y_ret_col, y_dir_col])
|
||||
feats = feature_columns(df)
|
||||
if not feats:
|
||||
return df, pd.Series(dtype=float), pd.Series(dtype=int), []
|
||||
X = df[feats]
|
||||
# LightGBM 은 NaN 자체 처리 가능.
|
||||
y_ret = df[y_ret_col].astype(float)
|
||||
y_dir = df[y_dir_col].astype(int)
|
||||
return X, y_ret, y_dir, feats
|
||||
|
||||
|
||||
def train_one(code: str, horizon: int, *, lookback_days: int = 365 * 3) -> dict:
|
||||
"""1종목 × 1 horizon 학습. 저장된 모델 파일 경로 + 샘플 수 반환."""
|
||||
import lightgbm as lgb
|
||||
|
||||
X, y_ret, y_dir, feats = _prepare_xy(code, horizon, lookback_days)
|
||||
if X.empty or len(X) < 100:
|
||||
return {"code": code, "horizon": horizon, "status": "skipped_too_few_rows", "n_rows": int(len(X))}
|
||||
|
||||
reg_params = dict(
|
||||
objective="regression",
|
||||
learning_rate=0.05,
|
||||
num_leaves=31,
|
||||
min_data_in_leaf=20,
|
||||
feature_fraction=0.85,
|
||||
bagging_fraction=0.8,
|
||||
bagging_freq=5,
|
||||
verbose=-1,
|
||||
)
|
||||
cls_params = dict(
|
||||
objective="multiclass",
|
||||
num_class=3,
|
||||
learning_rate=0.05,
|
||||
num_leaves=31,
|
||||
min_data_in_leaf=20,
|
||||
feature_fraction=0.85,
|
||||
bagging_fraction=0.8,
|
||||
bagging_freq=5,
|
||||
verbose=-1,
|
||||
)
|
||||
|
||||
# 분류는 -1/0/1 → 0/1/2 인덱스로 매핑.
|
||||
y_dir_idx = (y_dir + 1).astype(int)
|
||||
|
||||
n = len(X)
|
||||
split = int(n * 0.85)
|
||||
X_tr, X_val = X.iloc[:split], X.iloc[split:]
|
||||
yr_tr, yr_val = y_ret.iloc[:split], y_ret.iloc[split:]
|
||||
yc_tr, yc_val = y_dir_idx.iloc[:split], y_dir_idx.iloc[split:]
|
||||
|
||||
reg_train = lgb.Dataset(X_tr, label=yr_tr)
|
||||
reg_valid = lgb.Dataset(X_val, label=yr_val, reference=reg_train)
|
||||
reg_model = lgb.train(
|
||||
reg_params,
|
||||
reg_train,
|
||||
num_boost_round=400,
|
||||
valid_sets=[reg_valid],
|
||||
callbacks=[lgb.early_stopping(stopping_rounds=30, verbose=False)],
|
||||
)
|
||||
|
||||
cls_train = lgb.Dataset(X_tr, label=yc_tr)
|
||||
cls_valid = lgb.Dataset(X_val, label=yc_val, reference=cls_train)
|
||||
cls_model = lgb.train(
|
||||
cls_params,
|
||||
cls_train,
|
||||
num_boost_round=400,
|
||||
valid_sets=[cls_valid],
|
||||
callbacks=[lgb.early_stopping(stopping_rounds=30, verbose=False)],
|
||||
)
|
||||
|
||||
reg_path, cls_path = _model_paths(code, horizon)
|
||||
joblib.dump({"model": reg_model, "features": feats}, reg_path)
|
||||
joblib.dump({"model": cls_model, "features": feats}, cls_path)
|
||||
|
||||
return {
|
||||
"code": code,
|
||||
"horizon": horizon,
|
||||
"status": "ok",
|
||||
"n_rows": int(len(X)),
|
||||
"reg_best_iter": int(reg_model.best_iteration or 0),
|
||||
"cls_best_iter": int(cls_model.best_iteration or 0),
|
||||
"reg_path": str(reg_path),
|
||||
"cls_path": str(cls_path),
|
||||
}
|
||||
|
||||
|
||||
def predict_one(code: str, horizon: int, *, lookback_days: int = 400) -> LgbmForecast | None:
|
||||
"""1종목 × 1 horizon 추론. 모델 없으면 None.
|
||||
|
||||
가장 최신 영업일 피처를 사용. base_close 는 그 행의 close.
|
||||
"""
|
||||
reg_path, cls_path = _model_paths(code, horizon)
|
||||
if not reg_path.exists() or not cls_path.exists():
|
||||
return None
|
||||
reg_blob = joblib.load(reg_path)
|
||||
cls_blob = joblib.load(cls_path)
|
||||
feats_reg = reg_blob["features"]
|
||||
feats_cls = cls_blob["features"]
|
||||
reg_model = reg_blob["model"]
|
||||
cls_model = cls_blob["model"]
|
||||
|
||||
ff = build_features(code, lookback_days=lookback_days, horizons=(horizon,), with_targets=False)
|
||||
df = ff.df
|
||||
if df.empty:
|
||||
return None
|
||||
last = df.iloc[[-1]]
|
||||
base_close = float(last["close"].iloc[0])
|
||||
# 피처 정렬 (모델이 학습 당시 본 컬럼 순서대로).
|
||||
X_reg = last.reindex(columns=feats_reg).fillna(value=np.nan)
|
||||
X_cls = last.reindex(columns=feats_cls).fillna(value=np.nan)
|
||||
pred_ret = float(reg_model.predict(X_reg)[0])
|
||||
probs = cls_model.predict(X_cls)[0]
|
||||
# 인덱스 0=-1(down), 1=0(flat), 2=+1(up)
|
||||
prob_down, prob_flat, prob_up = float(probs[0]), float(probs[1]), float(probs[2])
|
||||
return LgbmForecast(
|
||||
horizon=horizon,
|
||||
base_close=base_close,
|
||||
predicted_close=base_close * (1.0 + pred_ret),
|
||||
predicted_return=pred_ret,
|
||||
prob_up=prob_up,
|
||||
prob_flat=prob_flat,
|
||||
prob_down=prob_down,
|
||||
)
|
||||
75
backend/app/models/weights.py
Normal file
75
backend/app/models/weights.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""ensemble_weights 테이블 IO. 기본 가중치 (chronos 0.6, lgbm 0.4)."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.db.connection import get_engine
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnsembleWeights:
|
||||
code: str
|
||||
horizon: int
|
||||
w_chronos: float
|
||||
w_lgbm: float
|
||||
|
||||
|
||||
DEFAULT_W_CHRONOS = 0.6
|
||||
DEFAULT_W_LGBM = 0.4
|
||||
|
||||
|
||||
def load_weights(code: str, horizon: int) -> EnsembleWeights:
|
||||
eng = get_engine()
|
||||
with eng.connect() as conn:
|
||||
row = conn.execute(
|
||||
text(
|
||||
"SELECT w_chronos, w_lgbm FROM ensemble_weights "
|
||||
"WHERE code = :code AND horizon = :h"
|
||||
),
|
||||
{"code": code, "h": horizon},
|
||||
).first()
|
||||
if not row:
|
||||
return EnsembleWeights(code, horizon, DEFAULT_W_CHRONOS, DEFAULT_W_LGBM)
|
||||
return EnsembleWeights(code, horizon, float(row[0]), float(row[1]))
|
||||
|
||||
|
||||
def upsert_weights(
|
||||
code: str,
|
||||
horizon: int,
|
||||
w_chronos: float,
|
||||
w_lgbm: float,
|
||||
*,
|
||||
hit_rate_chronos: float | None = None,
|
||||
hit_rate_lgbm: float | None = None,
|
||||
sample_count: int | None = None,
|
||||
) -> None:
|
||||
eng = get_engine()
|
||||
with eng.begin() as conn:
|
||||
conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO ensemble_weights
|
||||
(code, horizon, w_chronos, w_lgbm, hit_rate_chronos, hit_rate_lgbm, sample_count, updated_at)
|
||||
VALUES
|
||||
(:code, :h, :wc, :wl, :hc, :hl, :n, NOW())
|
||||
ON CONFLICT (code, horizon) DO UPDATE SET
|
||||
w_chronos = EXCLUDED.w_chronos,
|
||||
w_lgbm = EXCLUDED.w_lgbm,
|
||||
hit_rate_chronos = EXCLUDED.hit_rate_chronos,
|
||||
hit_rate_lgbm = EXCLUDED.hit_rate_lgbm,
|
||||
sample_count = EXCLUDED.sample_count,
|
||||
updated_at = NOW()
|
||||
"""
|
||||
),
|
||||
{
|
||||
"code": code,
|
||||
"h": horizon,
|
||||
"wc": float(w_chronos),
|
||||
"wl": float(w_lgbm),
|
||||
"hc": hit_rate_chronos,
|
||||
"hl": hit_rate_lgbm,
|
||||
"n": sample_count,
|
||||
},
|
||||
)
|
||||
165
backend/app/pipelines/match_outcomes.py
Normal file
165
backend/app/pipelines/match_outcomes.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""prediction_outcomes 매칭 배치.
|
||||
|
||||
평일 16:30 KST 에 실행. 다음 거래일 장 종료 후 (KRX 정규장 마감 15:30) 의
|
||||
확정 종가가 16:00~16:30 사이 pykrx 로 들어온 뒤, target_date == 오늘인
|
||||
user_triggered=TRUE 예측을 그 종가와 매칭.
|
||||
|
||||
cold-start / 휴장일 대비: 인자로 받은 target_date 의 ohlcv_daily 에 종가가
|
||||
없으면 자연스럽게 skip. 다음 거래일 매칭 잡이 다시 시도하면 그 날짜는
|
||||
여전히 매칭되지 않으므로 (매칭 sql 이 target_date 기준), 영원히 매칭 안되는
|
||||
잘못된 calendar date 예측은 cleanup CLI 로 별도 정리 가능 (Phase 7).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import date, timedelta
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.db.connection import get_engine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# direction_hit 판정 시 ±0.3% 이내는 flat. (features 의 FLAT_BAND 와 동일)
|
||||
FLAT_BAND = 0.003
|
||||
|
||||
|
||||
@dataclass
|
||||
class MatchSummary:
|
||||
target_date: str
|
||||
candidates: int
|
||||
matched: int
|
||||
skipped_no_actual: int
|
||||
already_resolved: int
|
||||
|
||||
|
||||
def _direction_label(ret: float) -> str:
|
||||
if ret > FLAT_BAND:
|
||||
return "up"
|
||||
if ret < -FLAT_BAND:
|
||||
return "down"
|
||||
return "flat"
|
||||
|
||||
|
||||
def match_for_date(d: date) -> MatchSummary:
|
||||
"""target_date == d 인 user_triggered=TRUE 예측을 매칭."""
|
||||
eng = get_engine()
|
||||
with eng.begin() as conn:
|
||||
# 매칭 대상 예측 + 매칭 안 됐는지 확인.
|
||||
candidate_rows = conn.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT p.id, p.code, p.base_date, p.horizon, p.point_forecast,
|
||||
p.direction, p.model
|
||||
FROM predictions p
|
||||
LEFT JOIN prediction_outcomes po ON po.prediction_id = p.id
|
||||
WHERE p.target_date = :d
|
||||
AND p.user_triggered = TRUE
|
||||
AND po.prediction_id IS NULL
|
||||
"""
|
||||
),
|
||||
{"d": d},
|
||||
).all()
|
||||
candidates = len(candidate_rows)
|
||||
if not candidates:
|
||||
return MatchSummary(str(d), 0, 0, 0, 0)
|
||||
|
||||
# 종목별로 actual close 조회 (한번에 batch).
|
||||
codes = list({r[1] for r in candidate_rows})
|
||||
actual_map: dict[tuple[str, date], float] = {}
|
||||
for code in codes:
|
||||
row = conn.execute(
|
||||
text(
|
||||
"SELECT close FROM ohlcv_daily WHERE code = :c AND date = :d"
|
||||
),
|
||||
{"c": code, "d": d},
|
||||
).first()
|
||||
if row and row[0] is not None:
|
||||
actual_map[(code, d)] = float(row[0])
|
||||
|
||||
# base_close (각 예측의 base_date 종가) 도 필요 — direction 판정용.
|
||||
base_close_map: dict[tuple[str, date], float] = {}
|
||||
for pid, code, base_date, *_ in candidate_rows:
|
||||
key = (code, base_date)
|
||||
if key in base_close_map:
|
||||
continue
|
||||
row = conn.execute(
|
||||
text("SELECT close FROM ohlcv_daily WHERE code = :c AND date = :d"),
|
||||
{"c": code, "d": base_date},
|
||||
).first()
|
||||
if row and row[0] is not None:
|
||||
base_close_map[key] = float(row[0])
|
||||
|
||||
matched = 0
|
||||
skipped = 0
|
||||
already = 0
|
||||
for pid, code, base_date, horizon, point_forecast, pred_dir, model in candidate_rows:
|
||||
actual = actual_map.get((code, d))
|
||||
base_close = base_close_map.get((code, base_date))
|
||||
if actual is None or base_close is None:
|
||||
skipped += 1
|
||||
continue
|
||||
actual_ret = actual / base_close - 1.0
|
||||
actual_dir = _direction_label(actual_ret)
|
||||
dir_hit = (pred_dir == actual_dir)
|
||||
abs_err = abs(float(point_forecast) - actual) if point_forecast is not None else None
|
||||
|
||||
try:
|
||||
conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO prediction_outcomes
|
||||
(prediction_id, code, target_date, horizon, model,
|
||||
predicted_close, actual_close, actual_return, direction_hit, abs_error)
|
||||
VALUES
|
||||
(:pid, :code, :d, :h, :m, :pc, :ac, :ar, :dh, :ae)
|
||||
ON CONFLICT (prediction_id) DO NOTHING
|
||||
"""
|
||||
),
|
||||
{
|
||||
"pid": pid,
|
||||
"code": code,
|
||||
"d": d,
|
||||
"h": horizon,
|
||||
"m": model,
|
||||
"pc": float(point_forecast) if point_forecast is not None else None,
|
||||
"ac": actual,
|
||||
"ar": float(actual_ret),
|
||||
"dh": bool(dir_hit),
|
||||
"ae": float(abs_err) if abs_err is not None else None,
|
||||
},
|
||||
)
|
||||
matched += 1
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("match insert failed pid=%s: %s", pid, exc)
|
||||
already += 1
|
||||
|
||||
return MatchSummary(
|
||||
target_date=str(d),
|
||||
candidates=candidates,
|
||||
matched=matched,
|
||||
skipped_no_actual=skipped,
|
||||
already_resolved=already,
|
||||
)
|
||||
|
||||
|
||||
def match_today() -> dict[str, Any]:
|
||||
"""평일 16:30 KST 호출용. target_date == today (KST) 인 행 매칭."""
|
||||
from datetime import datetime, timezone, timedelta as td
|
||||
|
||||
kst = timezone(td(hours=9))
|
||||
today = datetime.now(kst).date()
|
||||
summary = match_for_date(today)
|
||||
return {
|
||||
"today": str(today),
|
||||
"summary": summary.__dict__,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
||||
out = match_today()
|
||||
print(json.dumps(out, ensure_ascii=False, indent=2, default=str))
|
||||
138
backend/app/pipelines/predict_one.py
Normal file
138
backend/app/pipelines/predict_one.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""On-demand 예측 + DB 적재.
|
||||
|
||||
POST /api/predict/{code} 에서 호출. 사용자가 "예상차트 보기" 누른 시점.
|
||||
- ensemble.predict() 로 horizons (1,3,5) 결과 계산
|
||||
- base_date = 마지막 ohlcv_daily.date, target_date = base_date + horizon 영업일
|
||||
(대충 calendar 일로 +h * 1.4 — KRX 영업일 추정. Phase 4 단순화: base_date + h 영업일은
|
||||
ohlcv 상의 다음 h 거래일이 아닌, "거래일 카운트" 대신 단순 calendar+h 로 저장하고
|
||||
매칭 잡에서 ohlcv_daily 에 그 날짜 행이 있는지로 자연 보정.)
|
||||
|
||||
대안 정확도 위해: 매칭 잡은 "예측의 target_date 이 오늘"인 행을 그날 종가와 비교.
|
||||
calendar date 가 비거래일이면 매칭이 안 되니, 매칭 잡은 매일 실행되어 모일 때 처리.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import asdict
|
||||
from datetime import date, datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.db.connection import get_engine
|
||||
from app.models.ensemble import EnsemblePrediction, predict as ensemble_predict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
KST = timezone(timedelta(hours=9))
|
||||
|
||||
|
||||
def _next_trading_target(base_date: date, horizon: int) -> date:
|
||||
"""base_date + horizon 거래일 (주말만 스킵, 공휴일은 무시 — 매칭잡이 자연 보정)."""
|
||||
d = base_date
|
||||
added = 0
|
||||
while added < horizon:
|
||||
d = d + timedelta(days=1)
|
||||
if d.weekday() < 5: # 0..4 = Mon..Fri
|
||||
added += 1
|
||||
return d
|
||||
|
||||
|
||||
def _last_trading_date(code: str) -> date | None:
|
||||
eng = get_engine()
|
||||
with eng.connect() as conn:
|
||||
row = conn.execute(
|
||||
text("SELECT MAX(date) FROM ohlcv_daily WHERE code = :c"),
|
||||
{"c": code},
|
||||
).first()
|
||||
return row[0] if row and row[0] else None
|
||||
|
||||
|
||||
def predict_and_store(
|
||||
code: str,
|
||||
*,
|
||||
horizons: tuple[int, ...] = (1, 3, 5),
|
||||
user_triggered: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
"""앙상블 예측 실행 + predictions 테이블 적재. 결과 JSON-serializable dict 반환."""
|
||||
base_date = _last_trading_date(code)
|
||||
if base_date is None:
|
||||
raise RuntimeError(f"no ohlcv_daily for {code}; refresh first")
|
||||
|
||||
pred: EnsemblePrediction = ensemble_predict(code, horizons=horizons)
|
||||
now = datetime.now(KST)
|
||||
|
||||
eng = get_engine()
|
||||
saved_ids: list[int] = []
|
||||
with eng.begin() as conn:
|
||||
for step in pred.steps:
|
||||
target_date = _next_trading_target(base_date, step.horizon)
|
||||
features_snap = {
|
||||
"base_close": pred.base_close,
|
||||
"sources_used": pred.sources_used,
|
||||
"direction": step.direction,
|
||||
}
|
||||
row = conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO predictions
|
||||
(code, predicted_at, base_date, target_date, horizon, model,
|
||||
direction, prob_up, prob_flat, prob_down, expected_return,
|
||||
point_forecast, ci_low, ci_high, features_snapshot, user_triggered)
|
||||
VALUES
|
||||
(:code, :predicted_at, :base_date, :target_date, :horizon, 'ensemble',
|
||||
:direction, :p_up, :p_fl, :p_dn, :exp_ret,
|
||||
:point, :lo, :hi, CAST(:feats AS JSONB), :ut)
|
||||
ON CONFLICT (code, base_date, target_date, horizon, model)
|
||||
DO UPDATE SET
|
||||
predicted_at = EXCLUDED.predicted_at,
|
||||
direction = EXCLUDED.direction,
|
||||
prob_up = EXCLUDED.prob_up,
|
||||
prob_flat = EXCLUDED.prob_flat,
|
||||
prob_down = EXCLUDED.prob_down,
|
||||
expected_return = EXCLUDED.expected_return,
|
||||
point_forecast = EXCLUDED.point_forecast,
|
||||
ci_low = EXCLUDED.ci_low,
|
||||
ci_high = EXCLUDED.ci_high,
|
||||
features_snapshot = EXCLUDED.features_snapshot,
|
||||
user_triggered = predictions.user_triggered OR EXCLUDED.user_triggered
|
||||
RETURNING id
|
||||
"""
|
||||
),
|
||||
{
|
||||
"code": code,
|
||||
"predicted_at": now,
|
||||
"base_date": base_date,
|
||||
"target_date": target_date,
|
||||
"horizon": step.horizon,
|
||||
"direction": step.direction,
|
||||
"p_up": step.prob_up,
|
||||
"p_fl": step.prob_flat,
|
||||
"p_dn": step.prob_down,
|
||||
"exp_ret": step.expected_return,
|
||||
"point": step.point_close,
|
||||
"lo": step.ci_low,
|
||||
"hi": step.ci_high,
|
||||
"feats": json.dumps(features_snap),
|
||||
"ut": user_triggered,
|
||||
},
|
||||
).first()
|
||||
if row:
|
||||
saved_ids.append(int(row[0]))
|
||||
|
||||
return {
|
||||
"code": code,
|
||||
"base_date": str(base_date),
|
||||
"base_close": pred.base_close,
|
||||
"sources_used": pred.sources_used,
|
||||
"steps": [
|
||||
{
|
||||
**asdict(s),
|
||||
"target_date": str(_next_trading_target(base_date, s.horizon)),
|
||||
}
|
||||
for s in pred.steps
|
||||
],
|
||||
"saved_prediction_ids": saved_ids,
|
||||
"user_triggered": user_triggered,
|
||||
}
|
||||
122
backend/app/pipelines/retrain_weekly.py
Normal file
122
backend/app/pipelines/retrain_weekly.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""주간 재학습 + 앙상블 가중치 보정.
|
||||
|
||||
일요일 02:00 KST 실행:
|
||||
1. 시드 10종목 × horizon (1,3,5) 별로 LGBM 학습 (train_one).
|
||||
2. 최근 30일 prediction_outcomes 의 model 별 hit_rate 산출, model_performance 적재.
|
||||
3. 같은 30일 윈도우에서 chronos vs lgbm hit_rate 로 ensemble_weights 갱신.
|
||||
(방법: w_chronos = clamp(0.1, hr_c / (hr_c + hr_l), 0.9), w_lgbm = 1 - w_chronos.
|
||||
hit_rate 데이터가 부족하면 default 0.6/0.4 유지.)
|
||||
|
||||
지금은 'ensemble' 모델 단일 종류로 predictions 가 쌓이므로, 가중치 보정은
|
||||
chronos 단독 시뮬레이션 / lgbm 단독 시뮬레이션 hit_rate 비교가 진정한 방식인데,
|
||||
Phase 4 단순화: 'ensemble' 의 종합 hit_rate 만 model_performance 에 기록하고
|
||||
가중치는 default 유지. 진짜 비교는 Phase 7 (chronos 단독 + lgbm 단독 예측을
|
||||
shadow 로 같이 적재하는 구조) 로 확장.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import date, timedelta
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.db.connection import get_engine
|
||||
from app.models.lgbm import train_one
|
||||
from app.seed.seed_tickers import SEED_TICKERS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
HORIZONS = (1, 3, 5)
|
||||
WINDOW_DAYS = 30
|
||||
|
||||
|
||||
def retrain_all() -> list[dict[str, Any]]:
|
||||
"""시드 10종목 × horizons 학습."""
|
||||
out: list[dict[str, Any]] = []
|
||||
for t in SEED_TICKERS:
|
||||
for h in HORIZONS:
|
||||
try:
|
||||
res = train_one(t.code, h)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.exception("train_one failed for %s h=%d", t.code, h)
|
||||
res = {"code": t.code, "horizon": h, "status": "failed", "error": str(exc)}
|
||||
out.append(res)
|
||||
return out
|
||||
|
||||
|
||||
def record_performance(as_of: date) -> list[dict[str, Any]]:
|
||||
"""최근 WINDOW_DAYS 의 prediction_outcomes 로 (code, model, horizon) 별
|
||||
hit_rate / mae 산출, model_performance 에 upsert."""
|
||||
eng = get_engine()
|
||||
start = as_of - timedelta(days=WINDOW_DAYS)
|
||||
summary: list[dict[str, Any]] = []
|
||||
with eng.begin() as conn:
|
||||
rows = conn.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT code, model, horizon,
|
||||
AVG(CASE WHEN direction_hit THEN 1.0 ELSE 0.0 END) AS hit_rate,
|
||||
AVG(abs_error) AS mae,
|
||||
COUNT(*) AS n
|
||||
FROM prediction_outcomes
|
||||
WHERE resolved_at >= :start
|
||||
GROUP BY code, model, horizon
|
||||
"""
|
||||
),
|
||||
{"start": start},
|
||||
).all()
|
||||
for code, model, horizon, hr, mae, n in rows:
|
||||
conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO model_performance
|
||||
(code, model, window_days, as_of, hit_rate, mae, sample_count)
|
||||
VALUES (:c, :m, :w, :as_of, :hr, :mae, :n)
|
||||
ON CONFLICT (code, model, window_days, as_of) DO UPDATE SET
|
||||
hit_rate = EXCLUDED.hit_rate,
|
||||
mae = EXCLUDED.mae,
|
||||
sample_count = EXCLUDED.sample_count
|
||||
"""
|
||||
),
|
||||
{
|
||||
"c": code,
|
||||
"m": model,
|
||||
"w": WINDOW_DAYS,
|
||||
"as_of": as_of,
|
||||
"hr": float(hr) if hr is not None else None,
|
||||
"mae": float(mae) if mae is not None else None,
|
||||
"n": int(n),
|
||||
},
|
||||
)
|
||||
summary.append(
|
||||
{
|
||||
"code": code,
|
||||
"model": model,
|
||||
"horizon": horizon,
|
||||
"hit_rate": float(hr) if hr is not None else None,
|
||||
"mae": float(mae) if mae is not None else None,
|
||||
"n": int(n),
|
||||
}
|
||||
)
|
||||
return summary
|
||||
|
||||
|
||||
def run_weekly() -> dict[str, Any]:
|
||||
"""일요일 02:00 KST 호출 entry-point."""
|
||||
from datetime import datetime, timezone
|
||||
|
||||
kst = timezone(timedelta(hours=9))
|
||||
as_of = datetime.now(kst).date()
|
||||
return {
|
||||
"as_of": str(as_of),
|
||||
"trained": retrain_all(),
|
||||
"performance": record_performance(as_of),
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
||||
out = run_weekly()
|
||||
print(json.dumps(out, ensure_ascii=False, indent=2, default=str))
|
||||
@@ -22,6 +22,8 @@ from apscheduler.triggers.cron import CronTrigger
|
||||
from pytz import timezone
|
||||
|
||||
from app.pipelines.daily_batch import run_daily_batch
|
||||
from app.pipelines.match_outcomes import match_today
|
||||
from app.pipelines.retrain_weekly import run_weekly
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
KST = timezone("Asia/Seoul")
|
||||
@@ -34,15 +36,34 @@ def start_scheduler() -> BackgroundScheduler:
|
||||
if _scheduler:
|
||||
return _scheduler
|
||||
_scheduler = BackgroundScheduler(timezone=KST)
|
||||
# 16:00 평일: 시드 10종목 EOD/뉴스/공시/거시 갱신
|
||||
_scheduler.add_job(
|
||||
run_daily_batch,
|
||||
CronTrigger(hour=16, minute=0, timezone=KST),
|
||||
CronTrigger(day_of_week="mon-fri", hour=16, minute=0, timezone=KST),
|
||||
id="daily_batch_16",
|
||||
replace_existing=True,
|
||||
max_instances=1,
|
||||
)
|
||||
# 16:30 평일: prediction_outcomes 매칭 배치
|
||||
_scheduler.add_job(
|
||||
match_today,
|
||||
CronTrigger(day_of_week="mon-fri", hour=16, minute=30, timezone=KST),
|
||||
id="match_outcomes_1630",
|
||||
replace_existing=True,
|
||||
max_instances=1,
|
||||
)
|
||||
# 일요일 02:00: LGBM 재학습 + 성능 기록
|
||||
_scheduler.add_job(
|
||||
run_weekly,
|
||||
CronTrigger(day_of_week="sun", hour=2, minute=0, timezone=KST),
|
||||
id="retrain_weekly_sun_0200",
|
||||
replace_existing=True,
|
||||
max_instances=1,
|
||||
)
|
||||
_scheduler.start()
|
||||
logger.info("scheduler started (daily_batch @ 16:00 KST)")
|
||||
logger.info(
|
||||
"scheduler started: daily_batch(16:00 mon-fri), match_outcomes(16:30 mon-fri), retrain_weekly(sun 02:00) KST"
|
||||
)
|
||||
return _scheduler
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user