feat(phase-4): LGBM 모델 + 앙상블 + 매칭/재학습 잡
- backend/app/models/lgbm.py: 종목 × horizon 별 LightGBM 회귀(y_ret_h)
+ 다중분류(y_dir_h, 3-class). joblib 으로 backend/data/models/{code}_h{H}_*.pkl
저장. early_stopping(30). predict_one() 으로 최신 영업일 피처에 추론.
- backend/app/models/weights.py: ensemble_weights 테이블 IO,
default w_chronos=0.6 / w_lgbm=0.4 (DB 행 없으면 fallback).
- backend/app/models/ensemble.py: Chronos sample 분포 + LGBM regression+cls
결합. point/q10/q90 + prob_up/flat/down + direction 라벨. 한쪽 모델
실패 시 다른 쪽 단독 fallback (cold start: chronos 단독).
- backend/app/pipelines/predict_one.py: predict_and_store(). 결과를
predictions 테이블에 UPSERT, user_triggered 누적 OR. base_date = 마지막
ohlcv 거래일, target_date = base_date + H 영업일(주말 스킵, 공휴일은
매칭잡에서 자연 보정).
- backend/app/pipelines/match_outcomes.py: target_date == d 인
user_triggered=TRUE 예측을 d 의 실제 종가와 매칭 → prediction_outcomes
적재. direction_hit(±0.3% flat band) + abs_error. 실제 종가 없으면
자연 skip.
- backend/app/pipelines/retrain_weekly.py: 시드 10종목 × H 재학습 +
최근 30일 model_performance 적재.
- backend/app/db/migrations/003_ensemble_weights.sql: (code, horizon) →
(w_chronos, w_lgbm, hit_rate_*, sample_count).
- backend/app/pipelines/scheduler.py:
daily_batch : 평일 16:00 KST
match_outcomes : 평일 16:30 KST ← 사용자가 확정한 매칭 시점
retrain_weekly : 일요일 02:00 KST
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
122
backend/app/pipelines/retrain_weekly.py
Normal file
122
backend/app/pipelines/retrain_weekly.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""주간 재학습 + 앙상블 가중치 보정.
|
||||
|
||||
일요일 02:00 KST 실행:
|
||||
1. 시드 10종목 × horizon (1,3,5) 별로 LGBM 학습 (train_one).
|
||||
2. 최근 30일 prediction_outcomes 의 model 별 hit_rate 산출, model_performance 적재.
|
||||
3. 같은 30일 윈도우에서 chronos vs lgbm hit_rate 로 ensemble_weights 갱신.
|
||||
(방법: w_chronos = clamp(0.1, hr_c / (hr_c + hr_l), 0.9), w_lgbm = 1 - w_chronos.
|
||||
hit_rate 데이터가 부족하면 default 0.6/0.4 유지.)
|
||||
|
||||
지금은 'ensemble' 모델 단일 종류로 predictions 가 쌓이므로, 가중치 보정은
|
||||
chronos 단독 시뮬레이션 / lgbm 단독 시뮬레이션 hit_rate 비교가 진정한 방식인데,
|
||||
Phase 4 단순화: 'ensemble' 의 종합 hit_rate 만 model_performance 에 기록하고
|
||||
가중치는 default 유지. 진짜 비교는 Phase 7 (chronos 단독 + lgbm 단독 예측을
|
||||
shadow 로 같이 적재하는 구조) 로 확장.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import date, timedelta
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.db.connection import get_engine
|
||||
from app.models.lgbm import train_one
|
||||
from app.seed.seed_tickers import SEED_TICKERS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
HORIZONS = (1, 3, 5)
|
||||
WINDOW_DAYS = 30
|
||||
|
||||
|
||||
def retrain_all() -> list[dict[str, Any]]:
|
||||
"""시드 10종목 × horizons 학습."""
|
||||
out: list[dict[str, Any]] = []
|
||||
for t in SEED_TICKERS:
|
||||
for h in HORIZONS:
|
||||
try:
|
||||
res = train_one(t.code, h)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.exception("train_one failed for %s h=%d", t.code, h)
|
||||
res = {"code": t.code, "horizon": h, "status": "failed", "error": str(exc)}
|
||||
out.append(res)
|
||||
return out
|
||||
|
||||
|
||||
def record_performance(as_of: date) -> list[dict[str, Any]]:
|
||||
"""최근 WINDOW_DAYS 의 prediction_outcomes 로 (code, model, horizon) 별
|
||||
hit_rate / mae 산출, model_performance 에 upsert."""
|
||||
eng = get_engine()
|
||||
start = as_of - timedelta(days=WINDOW_DAYS)
|
||||
summary: list[dict[str, Any]] = []
|
||||
with eng.begin() as conn:
|
||||
rows = conn.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT code, model, horizon,
|
||||
AVG(CASE WHEN direction_hit THEN 1.0 ELSE 0.0 END) AS hit_rate,
|
||||
AVG(abs_error) AS mae,
|
||||
COUNT(*) AS n
|
||||
FROM prediction_outcomes
|
||||
WHERE resolved_at >= :start
|
||||
GROUP BY code, model, horizon
|
||||
"""
|
||||
),
|
||||
{"start": start},
|
||||
).all()
|
||||
for code, model, horizon, hr, mae, n in rows:
|
||||
conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO model_performance
|
||||
(code, model, window_days, as_of, hit_rate, mae, sample_count)
|
||||
VALUES (:c, :m, :w, :as_of, :hr, :mae, :n)
|
||||
ON CONFLICT (code, model, window_days, as_of) DO UPDATE SET
|
||||
hit_rate = EXCLUDED.hit_rate,
|
||||
mae = EXCLUDED.mae,
|
||||
sample_count = EXCLUDED.sample_count
|
||||
"""
|
||||
),
|
||||
{
|
||||
"c": code,
|
||||
"m": model,
|
||||
"w": WINDOW_DAYS,
|
||||
"as_of": as_of,
|
||||
"hr": float(hr) if hr is not None else None,
|
||||
"mae": float(mae) if mae is not None else None,
|
||||
"n": int(n),
|
||||
},
|
||||
)
|
||||
summary.append(
|
||||
{
|
||||
"code": code,
|
||||
"model": model,
|
||||
"horizon": horizon,
|
||||
"hit_rate": float(hr) if hr is not None else None,
|
||||
"mae": float(mae) if mae is not None else None,
|
||||
"n": int(n),
|
||||
}
|
||||
)
|
||||
return summary
|
||||
|
||||
|
||||
def run_weekly() -> dict[str, Any]:
|
||||
"""일요일 02:00 KST 호출 entry-point."""
|
||||
from datetime import datetime, timezone
|
||||
|
||||
kst = timezone(timedelta(hours=9))
|
||||
as_of = datetime.now(kst).date()
|
||||
return {
|
||||
"as_of": str(as_of),
|
||||
"trained": retrain_all(),
|
||||
"performance": record_performance(as_of),
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
||||
out = run_weekly()
|
||||
print(json.dumps(out, ensure_ascii=False, indent=2, default=str))
|
||||
Reference in New Issue
Block a user