feat(phase-4): LGBM 모델 + 앙상블 + 매칭/재학습 잡
- backend/app/models/lgbm.py: 종목 × horizon 별 LightGBM 회귀(y_ret_h)
+ 다중분류(y_dir_h, 3-class). joblib 으로 backend/data/models/{code}_h{H}_*.pkl
저장. early_stopping(30). predict_one() 으로 최신 영업일 피처에 추론.
- backend/app/models/weights.py: ensemble_weights 테이블 IO,
default w_chronos=0.6 / w_lgbm=0.4 (DB 행 없으면 fallback).
- backend/app/models/ensemble.py: Chronos sample 분포 + LGBM regression+cls
결합. point/q10/q90 + prob_up/flat/down + direction 라벨. 한쪽 모델
실패 시 다른 쪽 단독 fallback (cold start: chronos 단독).
- backend/app/pipelines/predict_one.py: predict_and_store(). 결과를
predictions 테이블에 UPSERT, user_triggered 누적 OR. base_date = 마지막
ohlcv 거래일, target_date = base_date + H 영업일(주말 스킵, 공휴일은
매칭잡에서 자연 보정).
- backend/app/pipelines/match_outcomes.py: target_date == d 인
user_triggered=TRUE 예측을 d 의 실제 종가와 매칭 → prediction_outcomes
적재. direction_hit(±0.3% flat band) + abs_error. 실제 종가 없으면
자연 skip.
- backend/app/pipelines/retrain_weekly.py: 시드 10종목 × H 재학습 +
최근 30일 model_performance 적재.
- backend/app/db/migrations/003_ensemble_weights.sql: (code, horizon) →
(w_chronos, w_lgbm, hit_rate_*, sample_count).
- backend/app/pipelines/scheduler.py:
daily_batch : 평일 16:00 KST
match_outcomes : 평일 16:30 KST ← 사용자가 확정한 매칭 시점
retrain_weekly : 일요일 02:00 KST
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
165
backend/app/pipelines/match_outcomes.py
Normal file
165
backend/app/pipelines/match_outcomes.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""prediction_outcomes 매칭 배치.
|
||||
|
||||
평일 16:30 KST 에 실행. 다음 거래일 장 종료 후 (KRX 정규장 마감 15:30) 의
|
||||
확정 종가가 16:00~16:30 사이 pykrx 로 들어온 뒤, target_date == 오늘인
|
||||
user_triggered=TRUE 예측을 그 종가와 매칭.
|
||||
|
||||
cold-start / 휴장일 대비: 인자로 받은 target_date 의 ohlcv_daily 에 종가가
|
||||
없으면 자연스럽게 skip. 다음 거래일 매칭 잡이 다시 시도하면 그 날짜는
|
||||
여전히 매칭되지 않으므로 (매칭 sql 이 target_date 기준), 영원히 매칭 안되는
|
||||
잘못된 calendar date 예측은 cleanup CLI 로 별도 정리 가능 (Phase 7).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from datetime import date, timedelta
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.db.connection import get_engine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# direction_hit 판정 시 ±0.3% 이내는 flat. (features 의 FLAT_BAND 와 동일)
|
||||
FLAT_BAND = 0.003
|
||||
|
||||
|
||||
@dataclass
|
||||
class MatchSummary:
|
||||
target_date: str
|
||||
candidates: int
|
||||
matched: int
|
||||
skipped_no_actual: int
|
||||
already_resolved: int
|
||||
|
||||
|
||||
def _direction_label(ret: float) -> str:
|
||||
if ret > FLAT_BAND:
|
||||
return "up"
|
||||
if ret < -FLAT_BAND:
|
||||
return "down"
|
||||
return "flat"
|
||||
|
||||
|
||||
def match_for_date(d: date) -> MatchSummary:
|
||||
"""target_date == d 인 user_triggered=TRUE 예측을 매칭."""
|
||||
eng = get_engine()
|
||||
with eng.begin() as conn:
|
||||
# 매칭 대상 예측 + 매칭 안 됐는지 확인.
|
||||
candidate_rows = conn.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT p.id, p.code, p.base_date, p.horizon, p.point_forecast,
|
||||
p.direction, p.model
|
||||
FROM predictions p
|
||||
LEFT JOIN prediction_outcomes po ON po.prediction_id = p.id
|
||||
WHERE p.target_date = :d
|
||||
AND p.user_triggered = TRUE
|
||||
AND po.prediction_id IS NULL
|
||||
"""
|
||||
),
|
||||
{"d": d},
|
||||
).all()
|
||||
candidates = len(candidate_rows)
|
||||
if not candidates:
|
||||
return MatchSummary(str(d), 0, 0, 0, 0)
|
||||
|
||||
# 종목별로 actual close 조회 (한번에 batch).
|
||||
codes = list({r[1] for r in candidate_rows})
|
||||
actual_map: dict[tuple[str, date], float] = {}
|
||||
for code in codes:
|
||||
row = conn.execute(
|
||||
text(
|
||||
"SELECT close FROM ohlcv_daily WHERE code = :c AND date = :d"
|
||||
),
|
||||
{"c": code, "d": d},
|
||||
).first()
|
||||
if row and row[0] is not None:
|
||||
actual_map[(code, d)] = float(row[0])
|
||||
|
||||
# base_close (각 예측의 base_date 종가) 도 필요 — direction 판정용.
|
||||
base_close_map: dict[tuple[str, date], float] = {}
|
||||
for pid, code, base_date, *_ in candidate_rows:
|
||||
key = (code, base_date)
|
||||
if key in base_close_map:
|
||||
continue
|
||||
row = conn.execute(
|
||||
text("SELECT close FROM ohlcv_daily WHERE code = :c AND date = :d"),
|
||||
{"c": code, "d": base_date},
|
||||
).first()
|
||||
if row and row[0] is not None:
|
||||
base_close_map[key] = float(row[0])
|
||||
|
||||
matched = 0
|
||||
skipped = 0
|
||||
already = 0
|
||||
for pid, code, base_date, horizon, point_forecast, pred_dir, model in candidate_rows:
|
||||
actual = actual_map.get((code, d))
|
||||
base_close = base_close_map.get((code, base_date))
|
||||
if actual is None or base_close is None:
|
||||
skipped += 1
|
||||
continue
|
||||
actual_ret = actual / base_close - 1.0
|
||||
actual_dir = _direction_label(actual_ret)
|
||||
dir_hit = (pred_dir == actual_dir)
|
||||
abs_err = abs(float(point_forecast) - actual) if point_forecast is not None else None
|
||||
|
||||
try:
|
||||
conn.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO prediction_outcomes
|
||||
(prediction_id, code, target_date, horizon, model,
|
||||
predicted_close, actual_close, actual_return, direction_hit, abs_error)
|
||||
VALUES
|
||||
(:pid, :code, :d, :h, :m, :pc, :ac, :ar, :dh, :ae)
|
||||
ON CONFLICT (prediction_id) DO NOTHING
|
||||
"""
|
||||
),
|
||||
{
|
||||
"pid": pid,
|
||||
"code": code,
|
||||
"d": d,
|
||||
"h": horizon,
|
||||
"m": model,
|
||||
"pc": float(point_forecast) if point_forecast is not None else None,
|
||||
"ac": actual,
|
||||
"ar": float(actual_ret),
|
||||
"dh": bool(dir_hit),
|
||||
"ae": float(abs_err) if abs_err is not None else None,
|
||||
},
|
||||
)
|
||||
matched += 1
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("match insert failed pid=%s: %s", pid, exc)
|
||||
already += 1
|
||||
|
||||
return MatchSummary(
|
||||
target_date=str(d),
|
||||
candidates=candidates,
|
||||
matched=matched,
|
||||
skipped_no_actual=skipped,
|
||||
already_resolved=already,
|
||||
)
|
||||
|
||||
|
||||
def match_today() -> dict[str, Any]:
|
||||
"""평일 16:30 KST 호출용. target_date == today (KST) 인 행 매칭."""
|
||||
from datetime import datetime, timezone, timedelta as td
|
||||
|
||||
kst = timezone(td(hours=9))
|
||||
today = datetime.now(kst).date()
|
||||
summary = match_for_date(today)
|
||||
return {
|
||||
"today": str(today),
|
||||
"summary": summary.__dict__,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
||||
out = match_today()
|
||||
print(json.dumps(out, ensure_ascii=False, indent=2, default=str))
|
||||
Reference in New Issue
Block a user