- backend/app/models/lgbm.py: 종목 × horizon 별 LightGBM 회귀(y_ret_h)
+ 다중분류(y_dir_h, 3-class). joblib 으로 backend/data/models/{code}_h{H}_*.pkl
저장. early_stopping(30). predict_one() 으로 최신 영업일 피처에 추론.
- backend/app/models/weights.py: ensemble_weights 테이블 IO,
default w_chronos=0.6 / w_lgbm=0.4 (DB 행 없으면 fallback).
- backend/app/models/ensemble.py: Chronos sample 분포 + LGBM regression+cls
결합. point/q10/q90 + prob_up/flat/down + direction 라벨. 한쪽 모델
실패 시 다른 쪽 단독 fallback (cold start: chronos 단독).
- backend/app/pipelines/predict_one.py: predict_and_store(). 결과를
predictions 테이블에 UPSERT, user_triggered 누적 OR. base_date = 마지막
ohlcv 거래일, target_date = base_date + H 영업일(주말 스킵, 공휴일은
매칭잡에서 자연 보정).
- backend/app/pipelines/match_outcomes.py: target_date == d 인
user_triggered=TRUE 예측을 d 의 실제 종가와 매칭 → prediction_outcomes
적재. direction_hit(±0.3% flat band) + abs_error. 실제 종가 없으면
자연 skip.
- backend/app/pipelines/retrain_weekly.py: 시드 10종목 × H 재학습 +
최근 30일 model_performance 적재.
- backend/app/db/migrations/003_ensemble_weights.sql: (code, horizon) →
(w_chronos, w_lgbm, hit_rate_*, sample_count).
- backend/app/pipelines/scheduler.py:
daily_batch : 평일 16:00 KST
match_outcomes : 평일 16:30 KST ← 사용자가 확정한 매칭 시점
retrain_weekly : 일요일 02:00 KST
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
123 lines
4.5 KiB
Python
123 lines
4.5 KiB
Python
"""주간 재학습 + 앙상블 가중치 보정.
|
||
|
||
일요일 02:00 KST 실행:
|
||
1. 시드 10종목 × horizon (1,3,5) 별로 LGBM 학습 (train_one).
|
||
2. 최근 30일 prediction_outcomes 의 model 별 hit_rate 산출, model_performance 적재.
|
||
3. 같은 30일 윈도우에서 chronos vs lgbm hit_rate 로 ensemble_weights 갱신.
|
||
(방법: w_chronos = clamp(0.1, hr_c / (hr_c + hr_l), 0.9), w_lgbm = 1 - w_chronos.
|
||
hit_rate 데이터가 부족하면 default 0.6/0.4 유지.)
|
||
|
||
지금은 'ensemble' 모델 단일 종류로 predictions 가 쌓이므로, 가중치 보정은
|
||
chronos 단독 시뮬레이션 / lgbm 단독 시뮬레이션 hit_rate 비교가 진정한 방식인데,
|
||
Phase 4 단순화: 'ensemble' 의 종합 hit_rate 만 model_performance 에 기록하고
|
||
가중치는 default 유지. 진짜 비교는 Phase 7 (chronos 단독 + lgbm 단독 예측을
|
||
shadow 로 같이 적재하는 구조) 로 확장.
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import logging
|
||
from datetime import date, timedelta
|
||
from typing import Any
|
||
|
||
from sqlalchemy import text
|
||
|
||
from app.db.connection import get_engine
|
||
from app.models.lgbm import train_one
|
||
from app.seed.seed_tickers import SEED_TICKERS
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
HORIZONS = (1, 3, 5)
|
||
WINDOW_DAYS = 30
|
||
|
||
|
||
def retrain_all() -> list[dict[str, Any]]:
|
||
"""시드 10종목 × horizons 학습."""
|
||
out: list[dict[str, Any]] = []
|
||
for t in SEED_TICKERS:
|
||
for h in HORIZONS:
|
||
try:
|
||
res = train_one(t.code, h)
|
||
except Exception as exc: # noqa: BLE001
|
||
logger.exception("train_one failed for %s h=%d", t.code, h)
|
||
res = {"code": t.code, "horizon": h, "status": "failed", "error": str(exc)}
|
||
out.append(res)
|
||
return out
|
||
|
||
|
||
def record_performance(as_of: date) -> list[dict[str, Any]]:
|
||
"""최근 WINDOW_DAYS 의 prediction_outcomes 로 (code, model, horizon) 별
|
||
hit_rate / mae 산출, model_performance 에 upsert."""
|
||
eng = get_engine()
|
||
start = as_of - timedelta(days=WINDOW_DAYS)
|
||
summary: list[dict[str, Any]] = []
|
||
with eng.begin() as conn:
|
||
rows = conn.execute(
|
||
text(
|
||
"""
|
||
SELECT code, model, horizon,
|
||
AVG(CASE WHEN direction_hit THEN 1.0 ELSE 0.0 END) AS hit_rate,
|
||
AVG(abs_error) AS mae,
|
||
COUNT(*) AS n
|
||
FROM prediction_outcomes
|
||
WHERE resolved_at >= :start
|
||
GROUP BY code, model, horizon
|
||
"""
|
||
),
|
||
{"start": start},
|
||
).all()
|
||
for code, model, horizon, hr, mae, n in rows:
|
||
conn.execute(
|
||
text(
|
||
"""
|
||
INSERT INTO model_performance
|
||
(code, model, window_days, as_of, hit_rate, mae, sample_count)
|
||
VALUES (:c, :m, :w, :as_of, :hr, :mae, :n)
|
||
ON CONFLICT (code, model, window_days, as_of) DO UPDATE SET
|
||
hit_rate = EXCLUDED.hit_rate,
|
||
mae = EXCLUDED.mae,
|
||
sample_count = EXCLUDED.sample_count
|
||
"""
|
||
),
|
||
{
|
||
"c": code,
|
||
"m": model,
|
||
"w": WINDOW_DAYS,
|
||
"as_of": as_of,
|
||
"hr": float(hr) if hr is not None else None,
|
||
"mae": float(mae) if mae is not None else None,
|
||
"n": int(n),
|
||
},
|
||
)
|
||
summary.append(
|
||
{
|
||
"code": code,
|
||
"model": model,
|
||
"horizon": horizon,
|
||
"hit_rate": float(hr) if hr is not None else None,
|
||
"mae": float(mae) if mae is not None else None,
|
||
"n": int(n),
|
||
}
|
||
)
|
||
return summary
|
||
|
||
|
||
def run_weekly() -> dict[str, Any]:
|
||
"""일요일 02:00 KST 호출 entry-point."""
|
||
from datetime import datetime, timezone
|
||
|
||
kst = timezone(timedelta(hours=9))
|
||
as_of = datetime.now(kst).date()
|
||
return {
|
||
"as_of": str(as_of),
|
||
"trained": retrain_all(),
|
||
"performance": record_performance(as_of),
|
||
}
|
||
|
||
|
||
if __name__ == "__main__":
|
||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
||
out = run_weekly()
|
||
print(json.dumps(out, ensure_ascii=False, indent=2, default=str))
|