- backend/app/models/lgbm.py: 종목 × horizon 별 LightGBM 회귀(y_ret_h)
+ 다중분류(y_dir_h, 3-class). joblib 으로 backend/data/models/{code}_h{H}_*.pkl
저장. early_stopping(30). predict_one() 으로 최신 영업일 피처에 추론.
- backend/app/models/weights.py: ensemble_weights 테이블 IO,
default w_chronos=0.6 / w_lgbm=0.4 (DB 행 없으면 fallback).
- backend/app/models/ensemble.py: Chronos sample 분포 + LGBM regression+cls
결합. point/q10/q90 + prob_up/flat/down + direction 라벨. 한쪽 모델
실패 시 다른 쪽 단독 fallback (cold start: chronos 단독).
- backend/app/pipelines/predict_one.py: predict_and_store(). 결과를
predictions 테이블에 UPSERT, user_triggered 누적 OR. base_date = 마지막
ohlcv 거래일, target_date = base_date + H 영업일(주말 스킵, 공휴일은
매칭잡에서 자연 보정).
- backend/app/pipelines/match_outcomes.py: target_date == d 인
user_triggered=TRUE 예측을 d 의 실제 종가와 매칭 → prediction_outcomes
적재. direction_hit(±0.3% flat band) + abs_error. 실제 종가 없으면
자연 skip.
- backend/app/pipelines/retrain_weekly.py: 시드 10종목 × H 재학습 +
최근 30일 model_performance 적재.
- backend/app/db/migrations/003_ensemble_weights.sql: (code, horizon) →
(w_chronos, w_lgbm, hit_rate_*, sample_count).
- backend/app/pipelines/scheduler.py:
daily_batch : 평일 16:00 KST
match_outcomes : 평일 16:30 KST ← 사용자가 확정한 매칭 시점
retrain_weekly : 일요일 02:00 KST
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
139 lines
5.3 KiB
Python
139 lines
5.3 KiB
Python
"""On-demand 예측 + DB 적재.
|
|
|
|
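Called from POST /api/predict/{code}, i.e. at the moment the user presses
"예상차트 보기" (view forecast chart).

- ensemble.predict() computes the results for horizons (1, 3, 5).
- base_date = the latest ohlcv_daily.date for the code;
  target_date = base_date + horizon business days.
  (Phase 4 simplification: "business days" are approximated by skipping
  Saturdays and Sundays only. The actual trading days recorded in ohlcv_daily
  are not counted, so a target_date can land on a public holiday. The matching
  job corrects for this naturally by checking whether ohlcv_daily has a row
  for that date.)

For accuracy, the matching job compares rows whose target_date is "today"
against that day's close. A target_date that falls on a non-trading day has no
close to match against, so the matching job runs every day and handles each
row once its data is available.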
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
from dataclasses import asdict
|
|
from datetime import date, datetime, timedelta, timezone
|
|
from typing import Any
|
|
|
|
from sqlalchemy import text
|
|
|
|
from app.db.connection import get_engine
|
|
from app.models.ensemble import EnsemblePrediction, predict as ensemble_predict
|
|
|
|
# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)

# Fixed UTC+09:00 offset (Korea Standard Time); passed to datetime.now() so
# that the predicted_at timestamps written by this module are timezone-aware.
KST = timezone(timedelta(hours=9))
|
|
|
|
|
|
def _next_trading_target(base_date: date, horizon: int) -> date:
    """Return the date that lies ``horizon`` weekdays after ``base_date``.

    Only Saturdays and Sundays are skipped. Public holidays are counted as if
    they were trading days; the outcome-matching job is expected to absorb
    that inaccuracy.

    Args:
        base_date: Date to start counting from (it is never counted itself).
        horizon: Number of weekdays to advance. A value of zero or less
            returns ``base_date`` unchanged.

    Returns:
        The resulting target date.
    """
    one_day = timedelta(days=1)
    cursor = base_date
    remaining = horizon
    while remaining > 0:
        cursor += one_day
        if cursor.weekday() >= 5:  # 5 = Saturday, 6 = Sunday: not counted
            continue
        remaining -= 1
    return cursor
|
|
|
|
|
|
def _last_trading_date(code: str) -> date | None:
    """Return the most recent ``ohlcv_daily.date`` stored for ``code``.

    Args:
        code: Value matched against the ``code`` column of ``ohlcv_daily``.

    Returns:
        The ``MAX(date)`` value for that code, or ``None`` when the query
        yields no row or a NULL/falsy maximum (i.e. no data for the code).
    """
    query = text("SELECT MAX(date) FROM ohlcv_daily WHERE code = :c")
    with get_engine().connect() as conn:
        latest = conn.execute(query, {"c": code}).first()
    if not latest:
        return None
    # MAX() over zero rows yields NULL; normalise any falsy value to None.
    return latest[0] or None
|
|
|
|
|
|
def predict_and_store(
    code: str,
    *,
    horizons: tuple[int, ...] = (1, 3, 5),
    user_triggered: bool = True,
) -> dict[str, Any]:
    """Run the ensemble prediction for ``code`` and upsert it into ``predictions``.

    One row per predicted step is written with ``model = 'ensemble'``. On a
    conflict over ``(code, base_date, target_date, horizon, model)`` the
    existing row is overwritten with the new forecast, while
    ``user_triggered`` is OR-ed with the stored flag so that it never flips
    back to false once set.

    Args:
        code: Identifier passed to the ensemble and stored in ``predictions``.
        horizons: Horizons forwarded to the ensemble predictor.
        user_triggered: Flag stored with each row (see the OR rule above).

    Returns:
        A JSON-serializable dict with ``code``, ``base_date`` (as a string),
        ``base_close``, ``sources_used``, ``steps`` (each step's fields plus
        its ``target_date`` string), ``saved_prediction_ids`` and
        ``user_triggered``.

    Raises:
        RuntimeError: If ``ohlcv_daily`` holds no date for ``code``.
    """
    base_date = _last_trading_date(code)
    if base_date is None:
        raise RuntimeError(f"no ohlcv_daily for {code}; refresh first")

    pred: EnsemblePrediction = ensemble_predict(code, horizons=horizons)
    now = datetime.now(KST)

    # Built once and reused for every step instead of once per iteration.
    upsert = text(
        """
        INSERT INTO predictions
        (code, predicted_at, base_date, target_date, horizon, model,
        direction, prob_up, prob_flat, prob_down, expected_return,
        point_forecast, ci_low, ci_high, features_snapshot, user_triggered)
        VALUES
        (:code, :predicted_at, :base_date, :target_date, :horizon, 'ensemble',
        :direction, :p_up, :p_fl, :p_dn, :exp_ret,
        :point, :lo, :hi, CAST(:feats AS JSONB), :ut)
        ON CONFLICT (code, base_date, target_date, horizon, model)
        DO UPDATE SET
        predicted_at = EXCLUDED.predicted_at,
        direction = EXCLUDED.direction,
        prob_up = EXCLUDED.prob_up,
        prob_flat = EXCLUDED.prob_flat,
        prob_down = EXCLUDED.prob_down,
        expected_return = EXCLUDED.expected_return,
        point_forecast = EXCLUDED.point_forecast,
        ci_low = EXCLUDED.ci_low,
        ci_high = EXCLUDED.ci_high,
        features_snapshot = EXCLUDED.features_snapshot,
        user_triggered = predictions.user_triggered OR EXCLUDED.user_triggered
        RETURNING id
        """
    )

    eng = get_engine()
    saved_ids: list[int] = []
    # Target date of each step, kept so the response reuses exactly the
    # values that were written to the database.
    target_dates: list[date] = []

    # All steps are written inside one engine.begin() block.
    with eng.begin() as conn:
        for step in pred.steps:
            target_date = _next_trading_target(base_date, step.horizon)
            target_dates.append(target_date)
            features_snap = {
                "base_close": pred.base_close,
                "sources_used": pred.sources_used,
                "direction": step.direction,
            }
            row = conn.execute(
                upsert,
                {
                    "code": code,
                    "predicted_at": now,
                    "base_date": base_date,
                    "target_date": target_date,
                    "horizon": step.horizon,
                    "direction": step.direction,
                    "p_up": step.prob_up,
                    "p_fl": step.prob_flat,
                    "p_dn": step.prob_down,
                    "exp_ret": step.expected_return,
                    "point": step.point_close,
                    "lo": step.ci_low,
                    "hi": step.ci_high,
                    "feats": json.dumps(features_snap),
                    "ut": user_triggered,
                },
            ).first()
            if row:
                saved_ids.append(int(row[0]))

    return {
        "code": code,
        "base_date": str(base_date),
        "base_close": pred.base_close,
        "sources_used": pred.sources_used,
        "steps": [
            {**asdict(step), "target_date": str(target_date)}
            for step, target_date in zip(pred.steps, target_dates)
        ],
        "saved_prediction_ids": saved_ids,
        "user_triggered": user_triggered,
    }
|