"""On-demand 예측 + DB 적재. POST /api/predict/{code} 에서 호출. 사용자가 "예상차트 보기" 누른 시점. - ensemble.predict() 로 horizons (1,3,5) 결과 계산 - base_date = 마지막 ohlcv_daily.date, target_date = base_date + horizon 영업일 (주말만 스킵하는 단순 카운트. 공휴일은 match_outcomes 가 "target_date 이후 최초 거래일 종가"로 자동 이월하여 보정.) 세 종류의 행을 함께 저장한다: - model='ensemble' : 사용자에게 보여주는 최종 예측. user_triggered 플래그 따라감. - model='chronos' : Chronos 단독 (shadow). user_triggered=FALSE 로 항상 적재. - model='lgbm' : LGBM 단독 (shadow). user_triggered=FALSE 로 항상 적재. shadow 행은 retrain_weekly 가 모델별 hit_rate 를 비교해 ensemble_weights 를 자동 보정하는 입력이 된다. """ from __future__ import annotations import json import logging from dataclasses import asdict from datetime import date, datetime, timedelta, timezone from typing import Any from sqlalchemy import text from app.db.connection import get_engine from app.models.ensemble import EnsemblePrediction, predict as ensemble_predict logger = logging.getLogger(__name__) KST = timezone(timedelta(hours=9)) # ±0.3% flat band — features.FLAT_BAND, match_outcomes.FLAT_BAND 와 동일. 
FLAT_BAND = 0.003

# Fixed ±3% band around the LGBM predicted close, stored as ci_low / ci_high
# on the 'lgbm' shadow rows.
_LGBM_CI_LOW_FACTOR = 0.97
_LGBM_CI_HIGH_FACTOR = 1.03

# (prob_up, prob_flat, prob_down) used when the Chronos sample distribution
# cannot be evaluated.
_UNIFORM_PROBS = (1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0)


def _next_trading_target(base_date: date, horizon: int) -> date:
    """Return the date ``horizon`` weekdays after ``base_date``.

    Only Saturdays and Sundays are skipped; public holidays are counted like
    ordinary weekdays (the outcome-matching job is expected to correct for
    them). A ``horizon`` of zero or less returns ``base_date`` unchanged.
    """
    d = base_date
    added = 0
    while added < horizon:
        d += timedelta(days=1)
        if d.weekday() < 5:  # 0..4 = Mon..Fri
            added += 1
    return d


def _last_trading_date(code: str) -> date | None:
    """Return the latest ``ohlcv_daily.date`` for ``code``, or None if absent."""
    eng = get_engine()
    with eng.connect() as conn:
        row = conn.execute(
            text("SELECT MAX(date) FROM ohlcv_daily WHERE code = :c"),
            {"c": code},
        ).first()
    return row[0] if row and row[0] else None


def _direction_label(ret: float) -> str:
    """Map a return to 'up' / 'down' / 'flat' using the ±FLAT_BAND threshold."""
    if ret > FLAT_BAND:
        return "up"
    if ret < -FLAT_BAND:
        return "down"
    return "flat"


# Upsert keyed on (code, base_date, target_date, horizon, model). On conflict
# the forecast columns are overwritten, while user_triggered is OR-ed so that a
# row once marked as user-triggered stays marked.
_INSERT_PREDICTION_SQL = text(
    """
    INSERT INTO predictions
        (code, predicted_at, base_date, target_date, horizon, model,
         direction, prob_up, prob_flat, prob_down,
         expected_return, point_forecast, ci_low, ci_high,
         features_snapshot, user_triggered)
    VALUES
        (:code, :predicted_at, :base_date, :target_date, :horizon, :model,
         :direction, :p_up, :p_fl, :p_dn,
         :exp_ret, :point, :lo, :hi,
         CAST(:feats AS JSONB), :ut)
    ON CONFLICT (code, base_date, target_date, horizon, model)
    DO UPDATE SET
        predicted_at = EXCLUDED.predicted_at,
        direction = EXCLUDED.direction,
        prob_up = EXCLUDED.prob_up,
        prob_flat = EXCLUDED.prob_flat,
        prob_down = EXCLUDED.prob_down,
        expected_return = EXCLUDED.expected_return,
        point_forecast = EXCLUDED.point_forecast,
        ci_low = EXCLUDED.ci_low,
        ci_high = EXCLUDED.ci_high,
        features_snapshot = EXCLUDED.features_snapshot,
        user_triggered = predictions.user_triggered OR EXCLUDED.user_triggered
    RETURNING id
    """
)


def _insert_prediction(
    conn,
    *,
    model: str,
    code: str,
    predicted_at: datetime,
    base_date: date,
    target_date: date,
    horizon: int,
    direction: str,
    p_up: float,
    p_fl: float,
    p_dn: float,
    expected_return: float,
    point: float | None,
    lo: float | None,
    hi: float | None,
    features_snap: dict,
    user_triggered: bool,
) -> int | None:
    """Upsert one ``predictions`` row and return its id.

    ``conn`` is the open connection the statement is executed on.
    ``features_snap`` is serialized with ``json.dumps`` and cast to JSONB by
    the statement. Returns None when the statement yields no row.
    """
    row = conn.execute(
        _INSERT_PREDICTION_SQL,
        {
            "code": code,
            "predicted_at": predicted_at,
            "base_date": base_date,
            "target_date": target_date,
            "horizon": horizon,
            "model": model,
            "direction": direction,
            "p_up": p_up,
            "p_fl": p_fl,
            "p_dn": p_dn,
            "exp_ret": expected_return,
            "point": point,
            "lo": lo,
            "hi": hi,
            "feats": json.dumps(features_snap),
            "ut": user_triggered,
        },
    ).first()
    return int(row[0]) if row else None


def _sample_direction_probs(
    samples: Any, horizon: int, base_close: float, *, code: str
) -> tuple[float, float, float]:
    """Return (prob_up, prob_flat, prob_down) from Chronos sample paths.

    Column ``horizon - 1`` of ``samples`` is converted to returns against
    ``base_close``; the share of samples above / below ±FLAT_BAND gives the
    up / down probabilities and the remainder is flat.

    This is best-effort: any failure (including an empty sample set, which
    would otherwise produce NaN probabilities) is logged and replaced by a
    uniform 1/3 split so the shadow row can still be written.
    """
    try:
        import numpy as np

        arr = np.array(samples)[:, horizon - 1]
        if arr.size == 0:
            raise ValueError("empty chronos sample set")
        ret = arr / base_close - 1.0
        p_up = float((ret > FLAT_BAND).mean())
        p_dn = float((ret < -FLAT_BAND).mean())
    except Exception:  # noqa: BLE001 - deliberate best-effort fallback
        logger.warning(
            "chronos sample probabilities unavailable for %s (horizon=%s); "
            "falling back to uniform probabilities",
            code,
            horizon,
            exc_info=True,
        )
        return _UNIFORM_PROBS
    return p_up, max(0.0, 1.0 - p_up - p_dn), p_dn


def predict_and_store(
    code: str,
    *,
    horizons: tuple[int, ...] = (1, 3, 5),
    user_triggered: bool = True,
) -> dict[str, Any]:
    """Run the ensemble prediction for ``code`` and store it in ``predictions``.

    Rows written for each predicted step, all inside one ``eng.begin()`` block:
      - 'ensemble' (carries the ``user_triggered`` argument)
      - 'chronos'  (shadow, user_triggered=FALSE) - only when
        ``pred.chronos_raw`` is not None
      - 'lgbm'     (shadow, user_triggered=FALSE) - only for horizons present
        in ``pred.lgbm_raw``

    Args:
        code: Instrument code to predict.
        horizons: Horizons forwarded to the ensemble predictor.
        user_triggered: Flag stored on the 'ensemble' rows and echoed back.

    Returns:
        A dict with ``code``, ``base_date``, ``base_close``, ``sources_used``,
        ``steps`` (each step's fields plus its ``target_date``),
        ``saved_prediction_ids`` (ensemble row ids), ``saved_shadow_ids`` and
        ``user_triggered``.

    Raises:
        RuntimeError: If ``ohlcv_daily`` has no rows for ``code``.
    """
    base_date = _last_trading_date(code)
    if base_date is None:
        raise RuntimeError(f"no ohlcv_daily for {code}; refresh first")

    pred: EnsemblePrediction = ensemble_predict(code, horizons=horizons)
    now = datetime.now(KST)
    base_close = pred.base_close
    chronos = pred.chronos_raw
    shadow_snap = {"shadow": True, "base_close": base_close}

    eng = get_engine()
    saved_ids: dict[str, list[int]] = {"ensemble": [], "chronos": [], "lgbm": []}
    # Target date per horizon, filled while storing and reused in the response.
    target_dates: dict[int, date] = {}

    with eng.begin() as conn:
        for step in pred.steps:
            target_date = _next_trading_target(base_date, step.horizon)
            target_dates[step.horizon] = target_date
            # Identifying columns shared by all rows written for this step.
            row_key = {
                "code": code,
                "predicted_at": now,
                "base_date": base_date,
                "target_date": target_date,
                "horizon": step.horizon,
            }

            # --- ensemble row ---
            pid = _insert_prediction(
                conn,
                model="ensemble",
                **row_key,
                direction=step.direction,
                p_up=step.prob_up,
                p_fl=step.prob_flat,
                p_dn=step.prob_down,
                expected_return=step.expected_return,
                point=step.point_close,
                lo=step.ci_low,
                hi=step.ci_high,
                features_snap={
                    "base_close": base_close,
                    "sources_used": pred.sources_used,
                    "direction": step.direction,
                },
                user_triggered=user_triggered,
            )
            if pid is not None:
                saved_ids["ensemble"].append(pid)

            # --- chronos shadow row ---
            if chronos is not None:
                idx = step.horizon - 1
                c_med = float(chronos.median[idx])
                c_q10 = float(chronos.q10[idx])
                c_q90 = float(chronos.q90[idx])
                # Direction probabilities come from the sample distribution.
                cp_up, cp_fl, cp_dn = _sample_direction_probs(
                    chronos.samples, step.horizon, base_close, code=code
                )
                exp_ret_c = c_med / base_close - 1.0
                pid_c = _insert_prediction(
                    conn,
                    model="chronos",
                    **row_key,
                    direction=_direction_label(exp_ret_c),
                    p_up=cp_up,
                    p_fl=cp_fl,
                    p_dn=cp_dn,
                    expected_return=exp_ret_c,
                    point=c_med,
                    lo=c_q10,
                    hi=c_q90,
                    features_snap=shadow_snap,
                    user_triggered=False,
                )
                if pid_c is not None:
                    saved_ids["chronos"].append(pid_c)

            # --- lgbm shadow row ---
            lf = pred.lgbm_raw.get(step.horizon)
            if lf is not None:
                l_close = float(lf.predicted_close)
                exp_ret_l = l_close / base_close - 1.0
                pid_l = _insert_prediction(
                    conn,
                    model="lgbm",
                    **row_key,
                    direction=_direction_label(exp_ret_l),
                    p_up=float(lf.prob_up),
                    p_fl=float(lf.prob_flat),
                    p_dn=float(lf.prob_down),
                    expected_return=exp_ret_l,
                    point=l_close,
                    lo=l_close * _LGBM_CI_LOW_FACTOR,
                    hi=l_close * _LGBM_CI_HIGH_FACTOR,
                    features_snap=shadow_snap,
                    user_triggered=False,
                )
                if pid_l is not None:
                    saved_ids["lgbm"].append(pid_l)

    return {
        "code": code,
        "base_date": str(base_date),
        "base_close": pred.base_close,
        "sources_used": pred.sources_used,
        "steps": [
            {**asdict(s), "target_date": str(target_dates[s.horizon])}
            for s in pred.steps
        ],
        # The UI only reads the ensemble ids; shadow ids are exposed under a
        # separate key for debugging / verification.
        "saved_prediction_ids": saved_ids["ensemble"],
        "saved_shadow_ids": {
            "chronos": saved_ids["chronos"],
            "lgbm": saved_ids["lgbm"],
        },
        "user_triggered": user_triggered,
    }