Files
stock_chart_site/backend/app/pipelines/retrain_weekly.py
tkrmagid bf4fb01146 feat(phase-4): LGBM 모델 + 앙상블 + 매칭/재학습 잡
- backend/app/models/lgbm.py: 종목 × horizon 별 LightGBM 회귀(y_ret_h)
  + 다중분류(y_dir_h, 3-class). joblib 으로 backend/data/models/{code}_h{H}_*.pkl
  저장. early_stopping(30). predict_one() 으로 최신 영업일 피처에 추론.
- backend/app/models/weights.py: ensemble_weights 테이블 IO,
  default w_chronos=0.6 / w_lgbm=0.4 (DB 행 없으면 fallback).
- backend/app/models/ensemble.py: Chronos sample 분포 + LGBM regression+cls
  결합. point/q10/q90 + prob_up/flat/down + direction 라벨. 한쪽 모델
  실패 시 다른 쪽 단독 fallback (cold start: chronos 단독).
- backend/app/pipelines/predict_one.py: predict_and_store(). 결과를
  predictions 테이블에 UPSERT, user_triggered 누적 OR. base_date = 마지막
  ohlcv 거래일, target_date = base_date + H 영업일(주말 스킵, 공휴일은
  매칭잡에서 자연 보정).
- backend/app/pipelines/match_outcomes.py: target_date == d 인
  user_triggered=TRUE 예측을 d 의 실제 종가와 매칭 → prediction_outcomes
  적재. direction_hit(±0.3% flat band) + abs_error. 실제 종가 없으면
  자연 skip.
- backend/app/pipelines/retrain_weekly.py: 시드 10종목 × H 재학습 +
  최근 30일 model_performance 적재.
- backend/app/db/migrations/003_ensemble_weights.sql: (code, horizon) →
  (w_chronos, w_lgbm, hit_rate_*, sample_count).
- backend/app/pipelines/scheduler.py:
    daily_batch    : 평일 16:00 KST
    match_outcomes : 평일 16:30 KST   ← 사용자가 확정한 매칭 시점
    retrain_weekly : 일요일 02:00 KST

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-20 16:03:01 +09:00

123 lines
4.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""주간 재학습 + 앙상블 가중치 보정.
일요일 02:00 KST 실행:
1. 시드 10종목 × horizon (1,3,5) 별로 LGBM 학습 (train_one).
2. 최근 30일 prediction_outcomes 의 model 별 hit_rate 산출, model_performance 적재.
3. 같은 30일 윈도우에서 chronos vs lgbm hit_rate 로 ensemble_weights 갱신.
(방법: w_chronos = clamp(0.1, hr_c / (hr_c + hr_l), 0.9), w_lgbm = 1 - w_chronos.
hit_rate 데이터가 부족하면 default 0.6/0.4 유지.)
지금은 'ensemble' 모델 단일 종류로 predictions 가 쌓이므로, 가중치 보정은
chronos 단독 시뮬레이션 / lgbm 단독 시뮬레이션 hit_rate 비교가 진정한 방식인데,
Phase 4 단순화: 'ensemble' 의 종합 hit_rate 만 model_performance 에 기록하고
가중치는 default 유지. 진짜 비교는 Phase 7 (chronos 단독 + lgbm 단독 예측을
shadow 로 같이 적재하는 구조) 로 확장.
"""
from __future__ import annotations
import json
import logging
from datetime import date, timedelta
from typing import Any
from sqlalchemy import text
from app.db.connection import get_engine
from app.models.lgbm import train_one
from app.seed.seed_tickers import SEED_TICKERS
# Module logger; records per-ticker training failures during the weekly job.
logger = logging.getLogger(__name__)

# Prediction horizons (in business days) to retrain for every seed ticker.
HORIZONS: tuple[int, ...] = (1, 3, 5)

# Trailing window, in calendar days, used when aggregating prediction_outcomes.
WINDOW_DAYS: int = 30
def retrain_all() -> list[dict[str, Any]]:
    """Train an LGBM model for every (seed ticker, horizon) pair.

    Training is best-effort. A failure for one pair is logged with its
    traceback and recorded in the result list, and the remaining pairs still
    get trained.

    Returns:
        One entry per pair, ordered by SEED_TICKERS and then by HORIZONS.
        A successful entry is whatever ``train_one`` returned. A failed entry
        is ``{"code", "horizon", "status": "failed", "error"}``.
    """

    def _attempt(code: str, horizon: int) -> dict[str, Any]:
        # Isolate each pair so one exception cannot abort the whole batch.
        try:
            return train_one(code, horizon)
        except Exception as exc:  # noqa: BLE001
            logger.exception("train_one failed for %s h=%d", code, horizon)
            return {"code": code, "horizon": horizon, "status": "failed", "error": str(exc)}

    return [_attempt(ticker.code, horizon) for ticker in SEED_TICKERS for horizon in HORIZONS]
def _as_float(value: Any) -> float | None:
    """Convert a DB aggregate (e.g. a Decimal) to float, passing SQL NULL through as None."""
    return float(value) if value is not None else None


def record_performance(as_of: date) -> list[dict[str, Any]]:
    """Aggregate recent prediction_outcomes and upsert them into model_performance.

    Outcomes with ``resolved_at`` in ``[as_of - WINDOW_DAYS, as_of + 1 day)``
    are counted. Rows whose ``direction_hit`` is NULL count as misses.

    model_performance is upserted on ``(code, model, window_days, as_of)``.
    That conflict key has no horizon, so the stored row is pooled over all
    horizons for each (code, model). Upserting one row per horizon against the
    same key would let whichever horizon came last silently overwrite the
    others. NOTE(review): storing per-horizon metrics would need a horizon
    column in model_performance's unique key.

    Args:
        as_of: Reference date for the window. It is also written to the
            ``as_of`` column.

    Returns:
        Per-(code, model, horizon) summaries ``{"code", "model", "horizon",
        "hit_rate", "mae", "n"}``, ordered by code, model, horizon. They are
        for reporting only; the DB holds the pooled rows.
    """
    eng = get_engine()
    start = as_of - timedelta(days=WINDOW_DAYS)
    # Exclusive upper bound at the start of the following day. All of as_of is
    # included, but outcomes resolved later are not. This matters when the job
    # is re-run or backfilled with a past as_of.
    end = as_of + timedelta(days=1)
    window = {"start": start, "end": end}

    with eng.begin() as conn:
        # Per-horizon breakdown, returned to the caller.
        per_horizon = conn.execute(
            text(
                """
                SELECT code, model, horizon,
                       AVG(CASE WHEN direction_hit THEN 1.0 ELSE 0.0 END) AS hit_rate,
                       AVG(abs_error) AS mae,
                       COUNT(*) AS n
                FROM prediction_outcomes
                WHERE resolved_at >= :start AND resolved_at < :end
                GROUP BY code, model, horizon
                ORDER BY code, model, horizon
                """
            ),
            window,
        ).all()

        # Pooled over horizons: exactly one row per model_performance key.
        pooled = conn.execute(
            text(
                """
                SELECT code, model,
                       AVG(CASE WHEN direction_hit THEN 1.0 ELSE 0.0 END) AS hit_rate,
                       AVG(abs_error) AS mae,
                       COUNT(*) AS n
                FROM prediction_outcomes
                WHERE resolved_at >= :start AND resolved_at < :end
                GROUP BY code, model
                """
            ),
            window,
        ).all()

        for code, model, hr, mae, n in pooled:
            conn.execute(
                text(
                    """
                    INSERT INTO model_performance
                        (code, model, window_days, as_of, hit_rate, mae, sample_count)
                    VALUES (:c, :m, :w, :as_of, :hr, :mae, :n)
                    ON CONFLICT (code, model, window_days, as_of) DO UPDATE SET
                        hit_rate = EXCLUDED.hit_rate,
                        mae = EXCLUDED.mae,
                        sample_count = EXCLUDED.sample_count
                    """
                ),
                {
                    "c": code,
                    "m": model,
                    "w": WINDOW_DAYS,
                    "as_of": as_of,
                    "hr": _as_float(hr),
                    "mae": _as_float(mae),
                    "n": int(n),
                },
            )

    return [
        {
            "code": code,
            "model": model,
            "horizon": horizon,
            "hit_rate": _as_float(hr),
            "mae": _as_float(mae),
            "n": int(n),
        }
        for code, model, horizon, hr, mae, n in per_horizon
    ]
def run_weekly() -> dict[str, Any]:
    """Entry point for the weekly job (scheduled for Sunday 02:00 KST).

    Retrains all models first, then records trailing-window performance.
    Both steps are anchored on today's date in KST (UTC+9).

    Returns:
        ``{"as_of": <ISO date str>, "trained": <retrain_all() result>,
        "performance": <record_performance() result>}``.
    """
    from datetime import datetime, timezone

    # Fixed UTC+9 offset, so the date does not depend on the host's local timezone.
    today_kst = datetime.now(timezone(timedelta(hours=9))).date()
    trained = retrain_all()
    performance = record_performance(today_kst)
    return {
        "as_of": str(today_kst),
        "trained": trained,
        "performance": performance,
    }
if __name__ == "__main__":
    # Manual run: configure root logging at INFO, run the job once, and
    # pretty-print the result. default=str covers values json cannot encode natively.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    )
    result = run_weekly()
    print(json.dumps(result, ensure_ascii=False, indent=2, default=str))