"""주간 재학습 + 앙상블 가중치 보정. 일요일 02:00 KST 실행: 1. 시드 10종목 × horizon (1,3,5) 별로 LGBM 학습 (train_one). 2. 최근 30일 prediction_outcomes 의 (code, model, horizon) 별 hit_rate / mae 산출, model_performance 적재. 3. shadow 행 (model='chronos' / 'lgbm') 의 hit_rate 를 비교해서 ensemble_weights 자동 보정. 가중치 공식: w_c = clamp(0.1, hr_c / (hr_c + hr_l), 0.9) w_l = 1 - w_c 단 sample_count_c < MIN_SAMPLE 또는 sample_count_l < MIN_SAMPLE 이면 기본값 유지 (DB row 미생성). hr_c + hr_l == 0 (둘 다 0%) 이면 50:50. predict_one 이 매 호출마다 chronos/lgbm shadow 행을 함께 적재하고 match_outcomes 가 user_triggered 무관하게 매칭하므로, hit_rate 데이터는 사용자가 예측을 한 번이라도 본 종목에 대해 자연스럽게 쌓인다. """ from __future__ import annotations import json import logging from datetime import date, timedelta from typing import Any from sqlalchemy import text from app.db.connection import get_engine from app.models.lgbm import train_one from app.models.weights import upsert_weights from app.seed.seed_tickers import SEED_TICKERS logger = logging.getLogger(__name__) HORIZONS = (1, 3, 5) WINDOW_DAYS = 30 MIN_SAMPLE = 10 # 모델당 최소 매칭 표본 W_CHRONOS_MIN = 0.1 W_CHRONOS_MAX = 0.9 def retrain_all() -> list[dict[str, Any]]: """시드 10종목 × horizons 학습.""" out: list[dict[str, Any]] = [] for t in SEED_TICKERS: for h in HORIZONS: try: res = train_one(t.code, h) except Exception as exc: # noqa: BLE001 logger.exception("train_one failed for %s h=%d", t.code, h) res = {"code": t.code, "horizon": h, "status": "failed", "error": str(exc)} out.append(res) return out def record_performance(as_of: date) -> list[dict[str, Any]]: """최근 WINDOW_DAYS 의 prediction_outcomes 로 (code, model, horizon) 별 hit_rate / mae 산출, model_performance 에 upsert.""" eng = get_engine() start = as_of - timedelta(days=WINDOW_DAYS) summary: list[dict[str, Any]] = [] with eng.begin() as conn: rows = conn.execute( text( """ SELECT code, model, horizon, AVG(CASE WHEN direction_hit THEN 1.0 ELSE 0.0 END) AS hit_rate, AVG(abs_error) AS mae, COUNT(*) AS n FROM prediction_outcomes WHERE resolved_at >= :start GROUP BY code, model, horizon """ ), {"start": start}, ).all() for code, model, horizon, hr, mae, n in rows: conn.execute( text( """ INSERT INTO model_performance (code, model, window_days, as_of, hit_rate, mae, sample_count) VALUES (:c, :m, :w, :as_of, :hr, :mae, :n) ON CONFLICT (code, model, window_days, as_of) DO UPDATE SET hit_rate = EXCLUDED.hit_rate, mae = EXCLUDED.mae, sample_count = EXCLUDED.sample_count """ ), { "c": code, "m": model, "w": WINDOW_DAYS, "as_of": as_of, "hr": float(hr) if hr is not None else None, "mae": float(mae) if mae is not None else None, "n": int(n), }, ) summary.append( { "code": code, "model": model, "horizon": horizon, "hit_rate": float(hr) if hr is not None else None, "mae": float(mae) if mae is not None else None, "n": int(n), } ) return summary def adjust_weights(as_of: date) -> list[dict[str, Any]]: """shadow chronos/lgbm hit_rate 로 ensemble_weights 자동 보정. 반환: (code, horizon, w_chronos, w_lgbm, hr_c, hr_l, n_c, n_l, action) 의 리스트. action ∈ {'updated', 'skipped_insufficient', 'skipped_zero'}. """ eng = get_engine() start = as_of - timedelta(days=WINDOW_DAYS) out: list[dict[str, Any]] = [] with eng.begin() as conn: rows = conn.execute( text( """ SELECT code, horizon, model, AVG(CASE WHEN direction_hit THEN 1.0 ELSE 0.0 END) AS hit_rate, COUNT(*) AS n FROM prediction_outcomes WHERE resolved_at >= :start AND model IN ('chronos', 'lgbm') GROUP BY code, horizon, model """ ), {"start": start}, ).all() # (code, horizon) -> {'chronos': (hr, n), 'lgbm': (hr, n)} agg: dict[tuple[str, int], dict[str, tuple[float, int]]] = {} for code, horizon, model, hr, n in rows: key = (code, int(horizon)) agg.setdefault(key, {})[str(model)] = ( float(hr) if hr is not None else 0.0, int(n), ) for (code, horizon), m in agg.items(): c = m.get("chronos") l = m.get("lgbm") if c is None or l is None: out.append({ "code": code, "horizon": horizon, "action": "skipped_missing_model", "have": list(m.keys()), }) continue hr_c, n_c = c hr_l, n_l = l if n_c < MIN_SAMPLE or n_l < MIN_SAMPLE: out.append({ "code": code, "horizon": horizon, "hr_chronos": hr_c, "hr_lgbm": hr_l, "n_chronos": n_c, "n_lgbm": n_l, "action": "skipped_insufficient", }) continue total = hr_c + hr_l if total <= 0: w_c = 0.5 else: w_c = hr_c / total w_c = max(W_CHRONOS_MIN, min(W_CHRONOS_MAX, w_c)) w_l = 1.0 - w_c upsert_weights( code, horizon, w_c, w_l, hit_rate_chronos=hr_c, hit_rate_lgbm=hr_l, sample_count=min(n_c, n_l), ) out.append({ "code": code, "horizon": horizon, "hr_chronos": hr_c, "hr_lgbm": hr_l, "n_chronos": n_c, "n_lgbm": n_l, "w_chronos": w_c, "w_lgbm": w_l, "action": "updated", }) return out def run_weekly() -> dict[str, Any]: """일요일 02:00 KST 호출 entry-point.""" from datetime import datetime, timezone kst = timezone(timedelta(hours=9)) as_of = datetime.now(kst).date() return { "as_of": str(as_of), "trained": retrain_all(), "performance": record_performance(as_of), "weights": adjust_weights(as_of), } if __name__ == "__main__": logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s") out = run_weekly() print(json.dumps(out, ensure_ascii=False, indent=2, default=str))