"""예측 API. POST /api/predict/{code} : on-demand 예측 + predictions 적재 (user_triggered=TRUE) GET /api/predict/{code}/latest : 가장 최근 base_date 의 예측 horizons 묶음 반환 (UI 가 새로고침해도 같은 결과 보여줄 때 사용) """ from __future__ import annotations import json import logging from datetime import datetime, timedelta, timezone from fastapi import APIRouter, HTTPException, Query from sqlalchemy import text from app.db.connection import get_engine from app.pipelines.predict_one import predict_and_store logger = logging.getLogger(__name__) KST = timezone(timedelta(hours=9)) router = APIRouter(prefix="/api/predict", tags=["predict"]) @router.post("/{code}") def predict_endpoint( code: str, horizons: str = Query(default="1,3,5", description="쉼표 구분 정수. ex) '1,3,5'"), user_triggered: bool = Query(default=True), ) -> dict: """on-demand 예측. UI 의 '예상차트 보기' 클릭이 호출.""" eng = get_engine() with eng.connect() as conn: row = conn.execute( text("SELECT code, name FROM symbols WHERE code = :c"), {"c": code}, ).first() if not row: raise HTTPException(status_code=404, detail=f"unknown code: {code}") try: hs = tuple(int(x) for x in horizons.split(",") if x.strip()) if not hs or any(h < 1 or h > 30 for h in hs): raise ValueError("invalid horizons") except ValueError as exc: raise HTTPException(status_code=400, detail=f"bad horizons: {exc}") try: result = predict_and_store(code, horizons=hs, user_triggered=user_triggered) except RuntimeError as exc: raise HTTPException(status_code=409, detail=str(exc)) except Exception as exc: # noqa: BLE001 logger.exception("predict_and_store failed for %s", code) raise HTTPException(status_code=500, detail=str(exc)) return result @router.get("/{code}/latest") def latest_prediction(code: str) -> dict: """가장 최근 base_date 에 대한 'ensemble' 예측 묶음.""" eng = get_engine() with eng.connect() as conn: symbol = conn.execute( text("SELECT code, name FROM symbols WHERE code = :c"), {"c": code}, ).first() if not symbol: raise HTTPException(status_code=404, detail=f"unknown code: {code}") latest_base = conn.execute( text( "SELECT MAX(base_date) FROM predictions " "WHERE code = :c AND model = 'ensemble'" ), {"c": code}, ).first() if not latest_base or latest_base[0] is None: return {"code": code, "found": False, "steps": []} base_date = latest_base[0] rows = conn.execute( text( """ SELECT predicted_at, target_date, horizon, direction, prob_up, prob_flat, prob_down, expected_return, point_forecast, ci_low, ci_high, user_triggered, features_snapshot FROM predictions WHERE code = :c AND base_date = :bd AND model = 'ensemble' ORDER BY horizon """ ), {"c": code, "bd": base_date}, ).all() # base_close (그 날 ohlcv) 도 같이 반환 — UI 가 차트 마지막 점에 이어붙일 때 사용 base_close_row = conn.execute( text("SELECT close FROM ohlcv_daily WHERE code = :c AND date = :d"), {"c": code, "d": base_date}, ).first() base_close = float(base_close_row[0]) if base_close_row and base_close_row[0] is not None else None steps = [] for r in rows: feats = r[12] if isinstance(feats, str): try: feats = json.loads(feats) except Exception: # noqa: BLE001 pass steps.append( { "predicted_at": r[0].isoformat() if r[0] else None, "target_date": str(r[1]), "horizon": int(r[2]), "direction": r[3], "prob_up": float(r[4]) if r[4] is not None else None, "prob_flat": float(r[5]) if r[5] is not None else None, "prob_down": float(r[6]) if r[6] is not None else None, "expected_return": float(r[7]) if r[7] is not None else None, "point_close": float(r[8]) if r[8] is not None else None, "ci_low": float(r[9]) if r[9] is not None else None, "ci_high": float(r[10]) if r[10] is not None else None, "user_triggered": bool(r[11]), "features_snapshot": feats, } ) return { "code": symbol[0], "name": symbol[1], "found": True, "base_date": str(base_date), "base_close": base_close, "steps": steps, }