- ensemble.predict() 가 chronos_raw / lgbm_raw 를 함께 반환
- predict_and_store() 가 매 호출마다 3종 행 적재:
model='ensemble' (user_triggered=인자)
model='chronos' (user_triggered=FALSE, shadow)
model='lgbm' (user_triggered=FALSE, shadow)
- retrain_weekly.adjust_weights(): 최근 30일 prediction_outcomes 의
chronos vs lgbm hit_rate 로 ensemble_weights upsert
w_chronos = clamp(0.1, hr_c/(hr_c+hr_l), 0.9), w_lgbm = 1 - w_chronos
모델별 표본 < 10 이면 기본값(0.6/0.4) 유지
- API 응답에 saved_shadow_ids 추가 (TS 타입도 동기화)
- README: 동작 모델 메모 섹션을 실제 구현과 일치하도록 갱신
리뷰어 지적 3번 (ensemble_weights 가 영원히 갱신 안 됨, upsert_weights 미호출) 해결.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
268 lines
9.6 KiB
Python
268 lines
9.6 KiB
Python
"""On-demand 예측 + DB 적재.
|
|
|
|
POST /api/predict/{code} 에서 호출. 사용자가 "예상차트 보기" 누른 시점.
|
|
- ensemble.predict() 로 horizons (1,3,5) 결과 계산
|
|
- base_date = 마지막 ohlcv_daily.date, target_date = base_date + horizon 영업일
|
|
(주말만 스킵하는 단순 카운트. 공휴일은 match_outcomes 가 "target_date 이후
|
|
최초 거래일 종가"로 자동 이월하여 보정.)
|
|
|
|
세 종류의 행을 함께 저장한다:
|
|
- model='ensemble' : 사용자에게 보여주는 최종 예측. user_triggered 플래그 따라감.
|
|
- model='chronos' : Chronos 단독 (shadow). user_triggered=FALSE 로 항상 적재.
|
|
- model='lgbm' : LGBM 단독 (shadow). user_triggered=FALSE 로 항상 적재.
|
|
|
|
shadow 행은 retrain_weekly 가 모델별 hit_rate 를 비교해 ensemble_weights 를
|
|
자동 보정하는 입력이 된다.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
from dataclasses import asdict
|
|
from datetime import date, datetime, timedelta, timezone
|
|
from typing import Any
|
|
|
|
from sqlalchemy import text
|
|
|
|
from app.db.connection import get_engine
|
|
from app.models.ensemble import EnsemblePrediction, predict as ensemble_predict
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Fixed UTC+9 offset (no DST handling); used to timestamp predicted_at.
KST = timezone(timedelta(hours=9))

# ±0.3% flat band: returns within [-FLAT_BAND, +FLAT_BAND] are labelled "flat".
# Per the original note this is the same value as features.FLAT_BAND and
# match_outcomes.FLAT_BAND — NOTE(review): not verifiable from this file; keep
# the three in sync when changing it.
FLAT_BAND = 0.003
|
|
|
|
|
|
def _next_trading_target(base_date: date, horizon: int) -> date:
|
|
"""base_date + horizon 거래일 (주말만 스킵, 공휴일은 무시 — 매칭잡이 자연 보정)."""
|
|
d = base_date
|
|
added = 0
|
|
while added < horizon:
|
|
d = d + timedelta(days=1)
|
|
if d.weekday() < 5: # 0..4 = Mon..Fri
|
|
added += 1
|
|
return d
|
|
|
|
|
|
def _last_trading_date(code: str) -> date | None:
    """Return the latest ``ohlcv_daily.date`` stored for ``code``.

    Returns ``None`` when the query yields no row or the maximum is
    NULL/falsy (i.e. there is no price history for the code).
    """
    engine = get_engine()
    query = text("SELECT MAX(date) FROM ohlcv_daily WHERE code = :c")
    with engine.connect() as conn:
        result = conn.execute(query, {"c": code}).first()
    if not result:
        return None
    latest = result[0]
    return latest if latest else None
|
|
|
|
|
|
def _direction_label(ret: float) -> str:
    """Map a return to "up", "down" or "flat" using the ±FLAT_BAND dead zone.

    Values strictly above ``FLAT_BAND`` are "up", strictly below
    ``-FLAT_BAND`` are "down"; everything else (including the band edges)
    is "flat".
    """
    is_up = ret > FLAT_BAND
    is_down = ret < -FLAT_BAND
    return "up" if is_up else ("down" if is_down else "flat")
|
|
|
|
|
|
# Upsert statement for one row of `predictions`.
# - Conflict key: (code, base_date, target_date, horizon, model). A repeated
#   prediction for the same key overwrites the stored forecast columns and
#   predicted_at instead of inserting a duplicate.
# - user_triggered is OR-ed with the stored value, so once a row has been
#   marked user-triggered a later non-user call does not clear the flag.
# - features_snapshot is bound as text (:feats) and cast to JSONB.
# - RETURNING id yields the row id for both the insert and the update path.
_INSERT_PREDICTION_SQL = text(
    """
    INSERT INTO predictions
    (code, predicted_at, base_date, target_date, horizon, model,
    direction, prob_up, prob_flat, prob_down, expected_return,
    point_forecast, ci_low, ci_high, features_snapshot, user_triggered)
    VALUES
    (:code, :predicted_at, :base_date, :target_date, :horizon, :model,
    :direction, :p_up, :p_fl, :p_dn, :exp_ret,
    :point, :lo, :hi, CAST(:feats AS JSONB), :ut)
    ON CONFLICT (code, base_date, target_date, horizon, model)
    DO UPDATE SET
    predicted_at = EXCLUDED.predicted_at,
    direction = EXCLUDED.direction,
    prob_up = EXCLUDED.prob_up,
    prob_flat = EXCLUDED.prob_flat,
    prob_down = EXCLUDED.prob_down,
    expected_return = EXCLUDED.expected_return,
    point_forecast = EXCLUDED.point_forecast,
    ci_low = EXCLUDED.ci_low,
    ci_high = EXCLUDED.ci_high,
    features_snapshot = EXCLUDED.features_snapshot,
    user_triggered = predictions.user_triggered OR EXCLUDED.user_triggered
    RETURNING id
    """
)
|
|
|
|
|
|
def _insert_prediction(
    conn,
    *,
    model: str,
    code: str,
    predicted_at: datetime,
    base_date: date,
    target_date: date,
    horizon: int,
    direction: str,
    p_up: float,
    p_fl: float,
    p_dn: float,
    expected_return: float,
    point: float | None,
    lo: float | None,
    hi: float | None,
    features_snap: dict,
    user_triggered: bool,
) -> int | None:
    """Upsert a single prediction row and return its id.

    Executes ``_INSERT_PREDICTION_SQL`` on ``conn`` with the given values.
    ``features_snap`` is serialised with ``json.dumps`` before binding.

    Returns:
        The ``id`` reported by ``RETURNING id``, or ``None`` when the
        statement returned no row.
    """
    bind_params = dict(
        code=code,
        predicted_at=predicted_at,
        base_date=base_date,
        target_date=target_date,
        horizon=horizon,
        model=model,
        direction=direction,
        p_up=p_up,
        p_fl=p_fl,
        p_dn=p_dn,
        exp_ret=expected_return,
        point=point,
        lo=lo,
        hi=hi,
        feats=json.dumps(features_snap),
        ut=user_triggered,
    )
    returned = conn.execute(_INSERT_PREDICTION_SQL, bind_params).first()
    if not returned:
        return None
    return int(returned[0])
|
|
|
|
|
|
def _chronos_direction_probs(
    chronos: Any, step_index: int, base_close: float
) -> tuple[float, float, float]:
    """Estimate (prob_up, prob_flat, prob_down) from the Chronos sample paths.

    ``chronos.samples`` is indexed as ``[:, step_index]``, i.e. it is treated
    as a 2-D collection whose second axis is the forecast step. Each sampled
    close is converted to a return against ``base_close`` and classified with
    the ±FLAT_BAND dead zone.

    This is best-effort: on any failure a warning (with traceback) is logged
    and uniform probabilities (1/3 each) are returned, so the shadow row can
    still be stored.
    """
    try:
        import numpy as np

        closes = np.array(chronos.samples)[:, step_index]
        returns = closes / base_close - 1.0
        p_up = float((returns > FLAT_BAND).mean())
        p_dn = float((returns < -FLAT_BAND).mean())
        # Clamp to guard against tiny negative values from float rounding.
        p_fl = max(0.0, 1.0 - p_up - p_dn)
    except Exception:  # noqa: BLE001 - deliberate best-effort fallback
        # Previously swallowed silently; a uniform fallback distorts the
        # shadow statistics, so make it visible in the logs.
        logger.warning(
            "chronos sample-based direction probs failed (step index %s); "
            "falling back to uniform probabilities",
            step_index,
            exc_info=True,
        )
        uniform = 1.0 / 3.0
        return uniform, uniform, uniform
    return p_up, p_fl, p_dn


def predict_and_store(
    code: str,
    *,
    horizons: tuple[int, ...] = (1, 3, 5),
    user_triggered: bool = True,
) -> dict[str, Any]:
    """Run the ensemble prediction and upsert its rows into ``predictions``.

    Rows written for each ensemble step:
      - 'ensemble': the final prediction; carries the ``user_triggered``
        argument.
      - 'chronos' : shadow row (user_triggered=False), only when
        ``pred.chronos_raw`` is not None.
      - 'lgbm'    : shadow row (user_triggered=False), only for horizons that
        are present in ``pred.lgbm_raw``.

    All rows of one call are written inside a single ``engine.begin()``
    transaction.

    Args:
        code: Instrument code to predict.
        horizons: Forecast horizons forwarded to the ensemble predictor.
        user_triggered: Flag stored on the 'ensemble' rows.

    Returns:
        A dict with ``code``, ``base_date`` (str), ``base_close``,
        ``sources_used``, ``steps`` (each step as a dict plus ``target_date``),
        ``saved_prediction_ids`` (ensemble row ids), ``saved_shadow_ids``
        (``{"chronos": [...], "lgbm": [...]}``) and ``user_triggered``.

    Raises:
        RuntimeError: if ``ohlcv_daily`` holds no row for ``code``.
    """
    base_date = _last_trading_date(code)
    if base_date is None:
        raise RuntimeError(f"no ohlcv_daily for {code}; refresh first")

    pred: EnsemblePrediction = ensemble_predict(code, horizons=horizons)
    now = datetime.now(KST)
    base_close = pred.base_close
    chronos = pred.chronos_raw

    # Computed once per horizon; reused for the DB rows and the response.
    target_dates = {
        s.horizon: _next_trading_target(base_date, s.horizon) for s in pred.steps
    }

    eng = get_engine()
    saved_ids: dict[str, list[int]] = {"ensemble": [], "chronos": [], "lgbm": []}
    with eng.begin() as conn:
        for step in pred.steps:
            horizon = step.horizon
            # Columns shared by the ensemble row and both shadow rows.
            row_key = {
                "code": code,
                "predicted_at": now,
                "base_date": base_date,
                "target_date": target_dates[horizon],
                "horizon": horizon,
            }

            # --- ensemble row ---
            ensemble_id = _insert_prediction(
                conn,
                model="ensemble",
                **row_key,
                direction=step.direction,
                p_up=step.prob_up,
                p_fl=step.prob_flat,
                p_dn=step.prob_down,
                expected_return=step.expected_return,
                point=step.point_close,
                lo=step.ci_low,
                hi=step.ci_high,
                features_snap={
                    "base_close": base_close,
                    "sources_used": pred.sources_used,
                    "direction": step.direction,
                },
                user_triggered=user_triggered,
            )
            if ensemble_id is not None:
                saved_ids["ensemble"].append(ensemble_id)

            # --- chronos shadow row ---
            if chronos is not None:
                # Forecast arrays are indexed by step, so horizon h is at h - 1.
                idx = horizon - 1
                c_med = float(chronos.median[idx])
                c_q10 = float(chronos.q10[idx])
                c_q90 = float(chronos.q90[idx])
                cp_up, cp_fl, cp_dn = _chronos_direction_probs(
                    chronos, idx, base_close
                )
                # The direction label comes from the median forecast, while
                # the probabilities come from the sample distribution.
                exp_ret_c = c_med / base_close - 1.0
                chronos_id = _insert_prediction(
                    conn,
                    model="chronos",
                    **row_key,
                    direction=_direction_label(exp_ret_c),
                    p_up=cp_up,
                    p_fl=cp_fl,
                    p_dn=cp_dn,
                    expected_return=exp_ret_c,
                    point=c_med,
                    lo=c_q10,
                    hi=c_q90,
                    features_snap={"shadow": True, "base_close": base_close},
                    user_triggered=False,
                )
                if chronos_id is not None:
                    saved_ids["chronos"].append(chronos_id)

            # --- lgbm shadow row ---
            lgbm = pred.lgbm_raw.get(horizon)
            if lgbm is not None:
                l_close = float(lgbm.predicted_close)
                exp_ret_l = l_close / base_close - 1.0
                lgbm_id = _insert_prediction(
                    conn,
                    model="lgbm",
                    **row_key,
                    direction=_direction_label(exp_ret_l),
                    p_up=float(lgbm.prob_up),
                    p_fl=float(lgbm.prob_flat),
                    p_dn=float(lgbm.prob_down),
                    expected_return=exp_ret_l,
                    point=l_close,
                    # No interval is read from the LGBM result here; a fixed
                    # ±3% band around the point forecast is stored instead.
                    lo=l_close * 0.97,
                    hi=l_close * 1.03,
                    features_snap={"shadow": True, "base_close": base_close},
                    user_triggered=False,
                )
                if lgbm_id is not None:
                    saved_ids["lgbm"].append(lgbm_id)

    return {
        "code": code,
        "base_date": str(base_date),
        "base_close": pred.base_close,
        "sources_used": pred.sources_used,
        "steps": [
            {**asdict(s), "target_date": str(target_dates[s.horizon])}
            for s in pred.steps
        ],
        # The UI only reads the ensemble ids; shadow ids are exposed under a
        # separate key for debugging/verification.
        "saved_prediction_ids": saved_ids["ensemble"],
        "saved_shadow_ids": {
            "chronos": saved_ids["chronos"],
            "lgbm": saved_ids["lgbm"],
        },
        "user_triggered": user_triggered,
    }
|