feat(phase-4): LGBM 모델 + 앙상블 + 매칭/재학습 잡

- backend/app/models/lgbm.py: 종목 × horizon 별 LightGBM 회귀(y_ret_h)
  + 다중분류(y_dir_h, 3-class). joblib 으로 backend/data/models/{code}_h{H}_*.pkl
  저장. early_stopping(30). predict_one() 으로 최신 영업일 피처에 추론.
- backend/app/models/weights.py: ensemble_weights 테이블 IO,
  default w_chronos=0.6 / w_lgbm=0.4 (DB 행 없으면 fallback).
- backend/app/models/ensemble.py: Chronos sample 분포 + LGBM regression+cls
  결합. point/q10/q90 + prob_up/flat/down + direction 라벨. 한쪽 모델
  실패 시 다른 쪽 단독 fallback (cold start: chronos 단독).
- backend/app/pipelines/predict_one.py: predict_and_store(). 결과를
  predictions 테이블에 UPSERT, user_triggered 누적 OR. base_date = 마지막
  ohlcv 거래일, target_date = base_date + H 영업일(주말 스킵, 공휴일은
  매칭잡에서 자연 보정).
- backend/app/pipelines/match_outcomes.py: target_date == d 인
  user_triggered=TRUE 예측을 d 의 실제 종가와 매칭 → prediction_outcomes
  적재. direction_hit(±0.3% flat band) + abs_error. 실제 종가 없으면
  자연 skip.
- backend/app/pipelines/retrain_weekly.py: 시드 10종목 × H 재학습 +
  최근 30일 model_performance 적재.
- backend/app/db/migrations/003_ensemble_weights.sql: (code, horizon) →
  (w_chronos, w_lgbm, hit_rate_*, sample_count).
- backend/app/pipelines/scheduler.py:
    daily_batch    : 평일 16:00 KST
    match_outcomes : 평일 16:30 KST   ← 사용자가 확정한 매칭 시점
    retrain_weekly : 일요일 02:00 KST

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
tkrmagid
2026-05-20 16:03:01 +09:00
parent b1ca6ab5d3
commit bf4fb01146
8 changed files with 896 additions and 2 deletions

View File

@@ -0,0 +1,165 @@
"""prediction_outcomes 매칭 배치.
평일 16:30 KST 에 실행. 다음 거래일 장 종료 후 (KRX 정규장 마감 15:30) 의
확정 종가가 16:00~16:30 사이 pykrx 로 들어온 뒤, target_date == 오늘인
user_triggered=TRUE 예측을 그 종가와 매칭.
cold-start / 휴장일 대비: 인자로 받은 target_date 의 ohlcv_daily 에 종가가
없으면 자연스럽게 skip. 다음 거래일 매칭 잡이 다시 시도하면 그 날짜는
여전히 매칭되지 않으므로 (매칭 sql 이 target_date 기준), 영원히 매칭 안되는
잘못된 calendar date 예측은 cleanup CLI 로 별도 정리 가능 (Phase 7).
"""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass
from datetime import date, timedelta
from typing import Any
from sqlalchemy import text
from app.db.connection import get_engine
logger = logging.getLogger(__name__)
# direction_hit 판정 시 ±0.3% 이내는 flat. (features 의 FLAT_BAND 와 동일)
FLAT_BAND = 0.003
@dataclass
class MatchSummary:
target_date: str
candidates: int
matched: int
skipped_no_actual: int
already_resolved: int
def _direction_label(ret: float) -> str:
if ret > FLAT_BAND:
return "up"
if ret < -FLAT_BAND:
return "down"
return "flat"
def match_for_date(d: date) -> MatchSummary:
"""target_date == d 인 user_triggered=TRUE 예측을 매칭."""
eng = get_engine()
with eng.begin() as conn:
# 매칭 대상 예측 + 매칭 안 됐는지 확인.
candidate_rows = conn.execute(
text(
"""
SELECT p.id, p.code, p.base_date, p.horizon, p.point_forecast,
p.direction, p.model
FROM predictions p
LEFT JOIN prediction_outcomes po ON po.prediction_id = p.id
WHERE p.target_date = :d
AND p.user_triggered = TRUE
AND po.prediction_id IS NULL
"""
),
{"d": d},
).all()
candidates = len(candidate_rows)
if not candidates:
return MatchSummary(str(d), 0, 0, 0, 0)
# 종목별로 actual close 조회 (한번에 batch).
codes = list({r[1] for r in candidate_rows})
actual_map: dict[tuple[str, date], float] = {}
for code in codes:
row = conn.execute(
text(
"SELECT close FROM ohlcv_daily WHERE code = :c AND date = :d"
),
{"c": code, "d": d},
).first()
if row and row[0] is not None:
actual_map[(code, d)] = float(row[0])
# base_close (각 예측의 base_date 종가) 도 필요 — direction 판정용.
base_close_map: dict[tuple[str, date], float] = {}
for pid, code, base_date, *_ in candidate_rows:
key = (code, base_date)
if key in base_close_map:
continue
row = conn.execute(
text("SELECT close FROM ohlcv_daily WHERE code = :c AND date = :d"),
{"c": code, "d": base_date},
).first()
if row and row[0] is not None:
base_close_map[key] = float(row[0])
matched = 0
skipped = 0
already = 0
for pid, code, base_date, horizon, point_forecast, pred_dir, model in candidate_rows:
actual = actual_map.get((code, d))
base_close = base_close_map.get((code, base_date))
if actual is None or base_close is None:
skipped += 1
continue
actual_ret = actual / base_close - 1.0
actual_dir = _direction_label(actual_ret)
dir_hit = (pred_dir == actual_dir)
abs_err = abs(float(point_forecast) - actual) if point_forecast is not None else None
try:
conn.execute(
text(
"""
INSERT INTO prediction_outcomes
(prediction_id, code, target_date, horizon, model,
predicted_close, actual_close, actual_return, direction_hit, abs_error)
VALUES
(:pid, :code, :d, :h, :m, :pc, :ac, :ar, :dh, :ae)
ON CONFLICT (prediction_id) DO NOTHING
"""
),
{
"pid": pid,
"code": code,
"d": d,
"h": horizon,
"m": model,
"pc": float(point_forecast) if point_forecast is not None else None,
"ac": actual,
"ar": float(actual_ret),
"dh": bool(dir_hit),
"ae": float(abs_err) if abs_err is not None else None,
},
)
matched += 1
except Exception as exc: # noqa: BLE001
logger.warning("match insert failed pid=%s: %s", pid, exc)
already += 1
return MatchSummary(
target_date=str(d),
candidates=candidates,
matched=matched,
skipped_no_actual=skipped,
already_resolved=already,
)
def match_today() -> dict[str, Any]:
"""평일 16:30 KST 호출용. target_date == today (KST) 인 행 매칭."""
from datetime import datetime, timezone, timedelta as td
kst = timezone(td(hours=9))
today = datetime.now(kst).date()
summary = match_for_date(today)
return {
"today": str(today),
"summary": summary.__dict__,
}
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
out = match_today()
print(json.dumps(out, ensure_ascii=False, indent=2, default=str))

View File

@@ -0,0 +1,138 @@
"""On-demand 예측 + DB 적재.
POST /api/predict/{code} 에서 호출. 사용자가 "예상차트 보기" 누른 시점.
- ensemble.predict() 로 horizons (1,3,5) 결과 계산
- base_date = 마지막 ohlcv_daily.date, target_date = base_date + horizon 영업일
(대충 calendar 일로 +h * 1.4 — KRX 영업일 추정. Phase 4 단순화: base_date + h 영업일은
ohlcv 상의 다음 h 거래일이 아닌, "거래일 카운트" 대신 단순 calendar+h 로 저장하고
매칭 잡에서 ohlcv_daily 에 그 날짜 행이 있는지로 자연 보정.)
대안 정확도 위해: 매칭 잡은 "예측의 target_date 이 오늘"인 행을 그날 종가와 비교.
calendar date 가 비거래일이면 매칭이 안 되니, 매칭 잡은 매일 실행되어 모일 때 처리.
"""
from __future__ import annotations
import json
import logging
from dataclasses import asdict
from datetime import date, datetime, timedelta, timezone
from typing import Any
from sqlalchemy import text
from app.db.connection import get_engine
from app.models.ensemble import EnsemblePrediction, predict as ensemble_predict
logger = logging.getLogger(__name__)
KST = timezone(timedelta(hours=9))
def _next_trading_target(base_date: date, horizon: int) -> date:
"""base_date + horizon 거래일 (주말만 스킵, 공휴일은 무시 — 매칭잡이 자연 보정)."""
d = base_date
added = 0
while added < horizon:
d = d + timedelta(days=1)
if d.weekday() < 5: # 0..4 = Mon..Fri
added += 1
return d
def _last_trading_date(code: str) -> date | None:
eng = get_engine()
with eng.connect() as conn:
row = conn.execute(
text("SELECT MAX(date) FROM ohlcv_daily WHERE code = :c"),
{"c": code},
).first()
return row[0] if row and row[0] else None
def predict_and_store(
code: str,
*,
horizons: tuple[int, ...] = (1, 3, 5),
user_triggered: bool = True,
) -> dict[str, Any]:
"""앙상블 예측 실행 + predictions 테이블 적재. 결과 JSON-serializable dict 반환."""
base_date = _last_trading_date(code)
if base_date is None:
raise RuntimeError(f"no ohlcv_daily for {code}; refresh first")
pred: EnsemblePrediction = ensemble_predict(code, horizons=horizons)
now = datetime.now(KST)
eng = get_engine()
saved_ids: list[int] = []
with eng.begin() as conn:
for step in pred.steps:
target_date = _next_trading_target(base_date, step.horizon)
features_snap = {
"base_close": pred.base_close,
"sources_used": pred.sources_used,
"direction": step.direction,
}
row = conn.execute(
text(
"""
INSERT INTO predictions
(code, predicted_at, base_date, target_date, horizon, model,
direction, prob_up, prob_flat, prob_down, expected_return,
point_forecast, ci_low, ci_high, features_snapshot, user_triggered)
VALUES
(:code, :predicted_at, :base_date, :target_date, :horizon, 'ensemble',
:direction, :p_up, :p_fl, :p_dn, :exp_ret,
:point, :lo, :hi, CAST(:feats AS JSONB), :ut)
ON CONFLICT (code, base_date, target_date, horizon, model)
DO UPDATE SET
predicted_at = EXCLUDED.predicted_at,
direction = EXCLUDED.direction,
prob_up = EXCLUDED.prob_up,
prob_flat = EXCLUDED.prob_flat,
prob_down = EXCLUDED.prob_down,
expected_return = EXCLUDED.expected_return,
point_forecast = EXCLUDED.point_forecast,
ci_low = EXCLUDED.ci_low,
ci_high = EXCLUDED.ci_high,
features_snapshot = EXCLUDED.features_snapshot,
user_triggered = predictions.user_triggered OR EXCLUDED.user_triggered
RETURNING id
"""
),
{
"code": code,
"predicted_at": now,
"base_date": base_date,
"target_date": target_date,
"horizon": step.horizon,
"direction": step.direction,
"p_up": step.prob_up,
"p_fl": step.prob_flat,
"p_dn": step.prob_down,
"exp_ret": step.expected_return,
"point": step.point_close,
"lo": step.ci_low,
"hi": step.ci_high,
"feats": json.dumps(features_snap),
"ut": user_triggered,
},
).first()
if row:
saved_ids.append(int(row[0]))
return {
"code": code,
"base_date": str(base_date),
"base_close": pred.base_close,
"sources_used": pred.sources_used,
"steps": [
{
**asdict(s),
"target_date": str(_next_trading_target(base_date, s.horizon)),
}
for s in pred.steps
],
"saved_prediction_ids": saved_ids,
"user_triggered": user_triggered,
}

View File

@@ -0,0 +1,122 @@
"""주간 재학습 + 앙상블 가중치 보정.
일요일 02:00 KST 실행:
1. 시드 10종목 × horizon (1,3,5) 별로 LGBM 학습 (train_one).
2. 최근 30일 prediction_outcomes 의 model 별 hit_rate 산출, model_performance 적재.
3. 같은 30일 윈도우에서 chronos vs lgbm hit_rate 로 ensemble_weights 갱신.
(방법: w_chronos = clamp(0.1, hr_c / (hr_c + hr_l), 0.9), w_lgbm = 1 - w_chronos.
hit_rate 데이터가 부족하면 default 0.6/0.4 유지.)
지금은 'ensemble' 모델 단일 종류로 predictions 가 쌓이므로, 가중치 보정은
chronos 단독 시뮬레이션 / lgbm 단독 시뮬레이션 hit_rate 비교가 진정한 방식인데,
Phase 4 단순화: 'ensemble' 의 종합 hit_rate 만 model_performance 에 기록하고
가중치는 default 유지. 진짜 비교는 Phase 7 (chronos 단독 + lgbm 단독 예측을
shadow 로 같이 적재하는 구조) 로 확장.
"""
from __future__ import annotations
import json
import logging
from datetime import date, timedelta
from typing import Any
from sqlalchemy import text
from app.db.connection import get_engine
from app.models.lgbm import train_one
from app.seed.seed_tickers import SEED_TICKERS
logger = logging.getLogger(__name__)
HORIZONS = (1, 3, 5)
WINDOW_DAYS = 30
def retrain_all() -> list[dict[str, Any]]:
"""시드 10종목 × horizons 학습."""
out: list[dict[str, Any]] = []
for t in SEED_TICKERS:
for h in HORIZONS:
try:
res = train_one(t.code, h)
except Exception as exc: # noqa: BLE001
logger.exception("train_one failed for %s h=%d", t.code, h)
res = {"code": t.code, "horizon": h, "status": "failed", "error": str(exc)}
out.append(res)
return out
def record_performance(as_of: date) -> list[dict[str, Any]]:
"""최근 WINDOW_DAYS 의 prediction_outcomes 로 (code, model, horizon) 별
hit_rate / mae 산출, model_performance 에 upsert."""
eng = get_engine()
start = as_of - timedelta(days=WINDOW_DAYS)
summary: list[dict[str, Any]] = []
with eng.begin() as conn:
rows = conn.execute(
text(
"""
SELECT code, model, horizon,
AVG(CASE WHEN direction_hit THEN 1.0 ELSE 0.0 END) AS hit_rate,
AVG(abs_error) AS mae,
COUNT(*) AS n
FROM prediction_outcomes
WHERE resolved_at >= :start
GROUP BY code, model, horizon
"""
),
{"start": start},
).all()
for code, model, horizon, hr, mae, n in rows:
conn.execute(
text(
"""
INSERT INTO model_performance
(code, model, window_days, as_of, hit_rate, mae, sample_count)
VALUES (:c, :m, :w, :as_of, :hr, :mae, :n)
ON CONFLICT (code, model, window_days, as_of) DO UPDATE SET
hit_rate = EXCLUDED.hit_rate,
mae = EXCLUDED.mae,
sample_count = EXCLUDED.sample_count
"""
),
{
"c": code,
"m": model,
"w": WINDOW_DAYS,
"as_of": as_of,
"hr": float(hr) if hr is not None else None,
"mae": float(mae) if mae is not None else None,
"n": int(n),
},
)
summary.append(
{
"code": code,
"model": model,
"horizon": horizon,
"hit_rate": float(hr) if hr is not None else None,
"mae": float(mae) if mae is not None else None,
"n": int(n),
}
)
return summary
def run_weekly() -> dict[str, Any]:
"""일요일 02:00 KST 호출 entry-point."""
from datetime import datetime, timezone
kst = timezone(timedelta(hours=9))
as_of = datetime.now(kst).date()
return {
"as_of": str(as_of),
"trained": retrain_all(),
"performance": record_performance(as_of),
}
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
out = run_weekly()
print(json.dumps(out, ensure_ascii=False, indent=2, default=str))

View File

@@ -22,6 +22,8 @@ from apscheduler.triggers.cron import CronTrigger
from pytz import timezone
from app.pipelines.daily_batch import run_daily_batch
from app.pipelines.match_outcomes import match_today
from app.pipelines.retrain_weekly import run_weekly
logger = logging.getLogger(__name__)
KST = timezone("Asia/Seoul")
@@ -34,15 +36,34 @@ def start_scheduler() -> BackgroundScheduler:
if _scheduler:
return _scheduler
_scheduler = BackgroundScheduler(timezone=KST)
# 16:00 평일: 시드 10종목 EOD/뉴스/공시/거시 갱신
_scheduler.add_job(
run_daily_batch,
CronTrigger(hour=16, minute=0, timezone=KST),
CronTrigger(day_of_week="mon-fri", hour=16, minute=0, timezone=KST),
id="daily_batch_16",
replace_existing=True,
max_instances=1,
)
# 16:30 평일: prediction_outcomes 매칭 배치
_scheduler.add_job(
match_today,
CronTrigger(day_of_week="mon-fri", hour=16, minute=30, timezone=KST),
id="match_outcomes_1630",
replace_existing=True,
max_instances=1,
)
# 일요일 02:00: LGBM 재학습 + 성능 기록
_scheduler.add_job(
run_weekly,
CronTrigger(day_of_week="sun", hour=2, minute=0, timezone=KST),
id="retrain_weekly_sun_0200",
replace_existing=True,
max_instances=1,
)
_scheduler.start()
logger.info("scheduler started (daily_batch @ 16:00 KST)")
logger.info(
"scheduler started: daily_batch(16:00 mon-fri), match_outcomes(16:30 mon-fri), retrain_weekly(sun 02:00) KST"
)
return _scheduler