- ensemble.predict() 가 chronos_raw / lgbm_raw 를 함께 반환
- predict_and_store() 가 매 호출마다 3종 행 적재:
model='ensemble' (user_triggered=인자)
model='chronos' (user_triggered=FALSE, shadow)
model='lgbm' (user_triggered=FALSE, shadow)
- retrain_weekly.adjust_weights(): 최근 30일 prediction_outcomes 의
chronos vs lgbm hit_rate 로 ensemble_weights upsert
w_chronos = clamp(0.1, hr_c/(hr_c+hr_l), 0.9), w_lgbm = 1 - w_chronos
모델별 표본 < 10 이면 기본값(0.6/0.4) 유지
- API 응답에 saved_shadow_ids 추가 (TS 타입도 동기화)
- README: 동작 모델 메모 섹션을 실제 구현과 일치하도록 갱신
리뷰어 지적 3번 (ensemble_weights 가 영원히 갱신 안됨, upsert_weights 미호출) 해결.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
208 lines
7.4 KiB
Python
208 lines
7.4 KiB
Python
"""주간 재학습 + 앙상블 가중치 보정.
|
||
|
||
일요일 02:00 KST 실행:
|
||
1. 시드 10종목 × horizon (1,3,5) 별로 LGBM 학습 (train_one).
|
||
2. 최근 30일 prediction_outcomes 의 (code, model, horizon) 별 hit_rate / mae
|
||
산출, model_performance 적재.
|
||
3. shadow 행 (model='chronos' / 'lgbm') 의 hit_rate 를 비교해서
|
||
ensemble_weights 자동 보정.
|
||
|
||
가중치 공식:
|
||
w_c = clamp(0.1, hr_c / (hr_c + hr_l), 0.9)
|
||
w_l = 1 - w_c
|
||
단 sample_count_c < MIN_SAMPLE 또는 sample_count_l < MIN_SAMPLE 이면
|
||
기본값 유지 (DB row 미생성). hr_c + hr_l == 0 (둘 다 0%) 이면 50:50.
|
||
|
||
predict_one 이 매 호출마다 chronos/lgbm shadow 행을 함께 적재하고
|
||
match_outcomes 가 user_triggered 무관하게 매칭하므로, hit_rate 데이터는
|
||
사용자가 예측을 한 번이라도 본 종목에 대해 자연스럽게 쌓인다.
|
||
"""
|
||
from __future__ import annotations
|
||
|
||
import json
|
||
import logging
|
||
from datetime import date, timedelta
|
||
from typing import Any
|
||
|
||
from sqlalchemy import text
|
||
|
||
from app.db.connection import get_engine
|
||
from app.models.lgbm import train_one
|
||
from app.models.weights import upsert_weights
|
||
from app.seed.seed_tickers import SEED_TICKERS
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
HORIZONS = (1, 3, 5)
|
||
WINDOW_DAYS = 30
|
||
MIN_SAMPLE = 10 # 모델당 최소 매칭 표본
|
||
W_CHRONOS_MIN = 0.1
|
||
W_CHRONOS_MAX = 0.9
|
||
|
||
|
||
def retrain_all() -> list[dict[str, Any]]:
    """Train one LGBM model per (seed ticker, horizon) pair.

    Every ticker in ``SEED_TICKERS`` is combined with every horizon in
    ``HORIZONS`` and handed to ``train_one``.  A failing pair is logged with
    its traceback and reported as a ``status="failed"`` entry, so one bad
    pair does not stop the rest of the batch.

    Returns:
        One entry per pair, in ticker-major order: whatever ``train_one``
        returned, or a failure dict with ``code``, ``horizon``, ``status``
        and ``error`` keys.
    """
    results: list[dict[str, Any]] = []
    for ticker, horizon in ((t, h) for t in SEED_TICKERS for h in HORIZONS):
        try:
            outcome = train_one(ticker.code, horizon)
        except Exception as exc:  # noqa: BLE001
            # Best effort by design: record the failure and keep going.
            logger.exception("train_one failed for %s h=%d", ticker.code, horizon)
            outcome = {
                "code": ticker.code,
                "horizon": horizon,
                "status": "failed",
                "error": str(exc),
            }
        results.append(outcome)
    return results
|
||
|
||
|
||
def record_performance(as_of: date) -> list[dict[str, Any]]:
    """Aggregate recent prediction outcomes and upsert them into model_performance.

    Reads the ``prediction_outcomes`` rows whose ``resolved_at`` is on or
    after ``as_of - WINDOW_DAYS`` days, groups them by (code, model, horizon)
    and computes the direction hit rate, the mean ``abs_error`` and the row
    count of each group.  Every group is then written to
    ``model_performance`` with an ``ON CONFLICT ... DO UPDATE`` upsert.  The
    read and all writes share one transaction.

    Args:
        as_of: Reference date; stored in the ``as_of`` column and used to
            derive the start of the window.

    Returns:
        One dict per group with ``code``, ``model``, ``horizon``,
        ``hit_rate``, ``mae`` and ``n`` keys.
    """
    select_stats = text(
        """
        SELECT code, model, horizon,
               AVG(CASE WHEN direction_hit THEN 1.0 ELSE 0.0 END) AS hit_rate,
               AVG(abs_error) AS mae,
               COUNT(*) AS n
        FROM prediction_outcomes
        WHERE resolved_at >= :start
        GROUP BY code, model, horizon
        """
    )
    # NOTE(review): the SELECT groups by horizon, but this upsert neither
    # writes a horizon nor includes one in its conflict target, so groups
    # that differ only by horizon appear to overwrite one another.
    # Confirm against the model_performance schema.
    upsert_stats = text(
        """
        INSERT INTO model_performance
            (code, model, window_days, as_of, hit_rate, mae, sample_count)
        VALUES (:c, :m, :w, :as_of, :hr, :mae, :n)
        ON CONFLICT (code, model, window_days, as_of) DO UPDATE SET
            hit_rate = EXCLUDED.hit_rate,
            mae = EXCLUDED.mae,
            sample_count = EXCLUDED.sample_count
        """
    )

    engine = get_engine()
    window_start = as_of - timedelta(days=WINDOW_DAYS)
    report: list[dict[str, Any]] = []
    with engine.begin() as conn:
        grouped = conn.execute(select_stats, {"start": window_start}).all()
        for code, model, horizon, raw_hit_rate, raw_mae, raw_count in grouped:
            # Convert the driver's numeric values once; NULL aggregates stay None.
            hit_rate = None if raw_hit_rate is None else float(raw_hit_rate)
            mean_abs_error = None if raw_mae is None else float(raw_mae)
            sample_count = int(raw_count)

            conn.execute(
                upsert_stats,
                {
                    "c": code,
                    "m": model,
                    "w": WINDOW_DAYS,
                    "as_of": as_of,
                    "hr": hit_rate,
                    "mae": mean_abs_error,
                    "n": sample_count,
                },
            )
            report.append(
                {
                    "code": code,
                    "model": model,
                    "horizon": horizon,
                    "hit_rate": hit_rate,
                    "mae": mean_abs_error,
                    "n": sample_count,
                }
            )
    return report
|
||
|
||
|
||
def _chronos_weight(hr_chronos: float, hr_lgbm: float) -> float:
    """Return the chronos share of the ensemble for a pair of hit rates.

    The share is ``hr_chronos / (hr_chronos + hr_lgbm)``, or 0.5 when the
    two hit rates do not sum to a positive number, clamped to
    ``[W_CHRONOS_MIN, W_CHRONOS_MAX]``.
    """
    total = hr_chronos + hr_lgbm
    share = 0.5 if total <= 0 else hr_chronos / total
    return max(W_CHRONOS_MIN, min(W_CHRONOS_MAX, share))


def adjust_weights(as_of: date) -> list[dict[str, Any]]:
    """Recalibrate ensemble weights from the chronos / lgbm hit rates.

    Reads the ``prediction_outcomes`` rows of models ``'chronos'`` and
    ``'lgbm'`` whose ``resolved_at`` is on or after ``as_of - WINDOW_DAYS``
    days, computes each model's direction hit rate per (code, horizon), and
    calls ``upsert_weights`` for every pair where both models have at least
    ``MIN_SAMPLE`` rows.

    Args:
        as_of: Reference date used to derive the start of the window.

    Returns:
        One dict per (code, horizon) seen in the window.  Every dict has
        ``code``, ``horizon`` and ``action``; the other keys depend on the
        action:

        * ``'skipped_missing_model'`` - only one of the two models has rows;
          ``have`` lists the models that are present.  Nothing is written.
        * ``'skipped_insufficient'`` - a model has fewer than ``MIN_SAMPLE``
          rows; ``hr_chronos``, ``hr_lgbm``, ``n_chronos`` and ``n_lgbm``
          are included.  Nothing is written.
        * ``'updated'`` - weights were passed to ``upsert_weights``; the
          hit-rate and count keys plus ``w_chronos`` and ``w_lgbm`` are
          included.  When both hit rates are zero the split is 0.5 / 0.5
          and the action is still ``'updated'``.
    """
    window_start = as_of - timedelta(days=WINDOW_DAYS)

    # Scope the read transaction to the SELECT so it is closed before any
    # weight is written through upsert_weights.
    with get_engine().begin() as conn:
        rows = conn.execute(
            text(
                """
                SELECT code, horizon, model,
                       AVG(CASE WHEN direction_hit THEN 1.0 ELSE 0.0 END) AS hit_rate,
                       COUNT(*) AS n
                FROM prediction_outcomes
                WHERE resolved_at >= :start
                  AND model IN ('chronos', 'lgbm')
                GROUP BY code, horizon, model
                """
            ),
            {"start": window_start},
        ).all()

    # (code, horizon) -> {model: (hit_rate, sample_count)}
    stats: dict[tuple[str, int], dict[str, tuple[float, int]]] = {}
    for code, horizon, model, hit_rate, count in rows:
        stats.setdefault((code, int(horizon)), {})[str(model)] = (
            float(hit_rate) if hit_rate is not None else 0.0,
            int(count),
        )

    report: list[dict[str, Any]] = []
    for (code, horizon), per_model in stats.items():
        chronos = per_model.get("chronos")
        lgbm = per_model.get("lgbm")
        if chronos is None or lgbm is None:
            # A ratio needs both models; report which one is present.
            report.append({
                "code": code, "horizon": horizon,
                "action": "skipped_missing_model",
                "have": list(per_model),
            })
            continue

        hr_chronos, n_chronos = chronos
        hr_lgbm, n_lgbm = lgbm
        entry: dict[str, Any] = {
            "code": code, "horizon": horizon,
            "hr_chronos": hr_chronos, "hr_lgbm": hr_lgbm,
            "n_chronos": n_chronos, "n_lgbm": n_lgbm,
        }
        if n_chronos < MIN_SAMPLE or n_lgbm < MIN_SAMPLE:
            # Too few samples: leave whatever weights are in effect untouched.
            entry["action"] = "skipped_insufficient"
            report.append(entry)
            continue

        w_chronos = _chronos_weight(hr_chronos, hr_lgbm)
        w_lgbm = 1.0 - w_chronos
        upsert_weights(
            code, horizon, w_chronos, w_lgbm,
            hit_rate_chronos=hr_chronos, hit_rate_lgbm=hr_lgbm,
            # The smaller sample is passed as the row's sample count.
            sample_count=min(n_chronos, n_lgbm),
        )
        entry["w_chronos"] = w_chronos
        entry["w_lgbm"] = w_lgbm
        entry["action"] = "updated"
        report.append(entry)
    return report
|
||
|
||
|
||
def run_weekly() -> dict[str, Any]:
    """Run the weekly job once (entry point for the Sunday 02:00 KST run).

    Takes today's date in KST (fixed UTC+9 offset) and runs the three
    stages in order: ``retrain_all``, ``record_performance``,
    ``adjust_weights``.

    Returns:
        A dict with ``as_of`` (the date as a string) and the result of each
        stage under ``trained``, ``performance`` and ``weights``.
    """
    from datetime import datetime, timezone

    today_kst = datetime.now(timezone(timedelta(hours=9))).date()

    trained = retrain_all()
    performance = record_performance(today_kst)
    weights = adjust_weights(today_kst)
    return {
        "as_of": str(today_kst),
        "trained": trained,
        "performance": performance,
        "weights": weights,
    }
|
||
|
||
|
||
if __name__ == "__main__":
    # Script mode: configure INFO logging, run the weekly job once and print
    # its report to stdout as pretty-printed JSON.  default=str stringifies
    # any value json cannot serialize natively.
    log_format = "%(asctime)s %(levelname)s %(name)s: %(message)s"
    logging.basicConfig(level=logging.INFO, format=log_format)
    print(json.dumps(run_weekly(), ensure_ascii=False, indent=2, default=str))
|