feat(phase-5): FastAPI 엔드포인트 (검색/차트/예측/메트릭/뉴스)
- GET /api/symbols/search?q=...&seed_only= : trigram + prefix + ILIKE 합산 정렬
- GET /api/symbols/{code} : 메타
- GET /api/chart/{code}?days=N&include_* : OHLCV + 일별 감성 + 외인기관거래대금
- POST /api/predict/{code}?horizons=1,3,5 : on-demand 앙상블 예측 + DB 적재
(user_triggered=TRUE)
- GET /api/predict/{code}/latest : 최신 base_date 의 예측 묶음 + base_close
(UI 가 차트 마지막 점에 이어 붙임)
- GET /api/metrics/{code}?window_days=N : 종목 단위 hit_rate / mae (model, horizon 별)
- GET /api/metrics?window_days=N : 전체 누적
- GET /api/news/{code}?source=&limit= : 최신순 뉴스/공시 목록 (감성 점수 포함)
main.py 에 6개 라우터 모두 include.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
137
backend/app/api/predict.py
Normal file
137
backend/app/api/predict.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""예측 API.
|
||||
|
||||
POST /api/predict/{code} : on-demand 예측 + predictions 적재 (user_triggered=TRUE)
|
||||
GET /api/predict/{code}/latest : 가장 최근 base_date 의 예측 horizons 묶음 반환
|
||||
(UI 가 새로고침해도 같은 결과 보여줄 때 사용)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.db.connection import get_engine
|
||||
from app.pipelines.predict_one import predict_and_store
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
KST = timezone(timedelta(hours=9))
|
||||
|
||||
router = APIRouter(prefix="/api/predict", tags=["predict"])
|
||||
|
||||
|
||||
@router.post("/{code}")
|
||||
def predict_endpoint(
|
||||
code: str,
|
||||
horizons: str = Query(default="1,3,5", description="쉼표 구분 정수. ex) '1,3,5'"),
|
||||
user_triggered: bool = Query(default=True),
|
||||
) -> dict:
|
||||
"""on-demand 예측. UI 의 '예상차트 보기' 클릭이 호출."""
|
||||
eng = get_engine()
|
||||
with eng.connect() as conn:
|
||||
row = conn.execute(
|
||||
text("SELECT code, name FROM symbols WHERE code = :c"),
|
||||
{"c": code},
|
||||
).first()
|
||||
if not row:
|
||||
raise HTTPException(status_code=404, detail=f"unknown code: {code}")
|
||||
|
||||
try:
|
||||
hs = tuple(int(x) for x in horizons.split(",") if x.strip())
|
||||
if not hs or any(h < 1 or h > 30 for h in hs):
|
||||
raise ValueError("invalid horizons")
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=400, detail=f"bad horizons: {exc}")
|
||||
|
||||
try:
|
||||
result = predict_and_store(code, horizons=hs, user_triggered=user_triggered)
|
||||
except RuntimeError as exc:
|
||||
raise HTTPException(status_code=409, detail=str(exc))
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.exception("predict_and_store failed for %s", code)
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/{code}/latest")
|
||||
def latest_prediction(code: str) -> dict:
|
||||
"""가장 최근 base_date 에 대한 'ensemble' 예측 묶음."""
|
||||
eng = get_engine()
|
||||
with eng.connect() as conn:
|
||||
symbol = conn.execute(
|
||||
text("SELECT code, name FROM symbols WHERE code = :c"),
|
||||
{"c": code},
|
||||
).first()
|
||||
if not symbol:
|
||||
raise HTTPException(status_code=404, detail=f"unknown code: {code}")
|
||||
|
||||
latest_base = conn.execute(
|
||||
text(
|
||||
"SELECT MAX(base_date) FROM predictions "
|
||||
"WHERE code = :c AND model = 'ensemble'"
|
||||
),
|
||||
{"c": code},
|
||||
).first()
|
||||
if not latest_base or latest_base[0] is None:
|
||||
return {"code": code, "found": False, "steps": []}
|
||||
base_date = latest_base[0]
|
||||
|
||||
rows = conn.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT predicted_at, target_date, horizon, direction,
|
||||
prob_up, prob_flat, prob_down, expected_return,
|
||||
point_forecast, ci_low, ci_high, user_triggered,
|
||||
features_snapshot
|
||||
FROM predictions
|
||||
WHERE code = :c AND base_date = :bd AND model = 'ensemble'
|
||||
ORDER BY horizon
|
||||
"""
|
||||
),
|
||||
{"c": code, "bd": base_date},
|
||||
).all()
|
||||
|
||||
# base_close (그 날 ohlcv) 도 같이 반환 — UI 가 차트 마지막 점에 이어붙일 때 사용
|
||||
base_close_row = conn.execute(
|
||||
text("SELECT close FROM ohlcv_daily WHERE code = :c AND date = :d"),
|
||||
{"c": code, "d": base_date},
|
||||
).first()
|
||||
base_close = float(base_close_row[0]) if base_close_row and base_close_row[0] is not None else None
|
||||
|
||||
steps = []
|
||||
for r in rows:
|
||||
feats = r[12]
|
||||
if isinstance(feats, str):
|
||||
try:
|
||||
feats = json.loads(feats)
|
||||
except Exception: # noqa: BLE001
|
||||
pass
|
||||
steps.append(
|
||||
{
|
||||
"predicted_at": r[0].isoformat() if r[0] else None,
|
||||
"target_date": str(r[1]),
|
||||
"horizon": int(r[2]),
|
||||
"direction": r[3],
|
||||
"prob_up": float(r[4]) if r[4] is not None else None,
|
||||
"prob_flat": float(r[5]) if r[5] is not None else None,
|
||||
"prob_down": float(r[6]) if r[6] is not None else None,
|
||||
"expected_return": float(r[7]) if r[7] is not None else None,
|
||||
"point_close": float(r[8]) if r[8] is not None else None,
|
||||
"ci_low": float(r[9]) if r[9] is not None else None,
|
||||
"ci_high": float(r[10]) if r[10] is not None else None,
|
||||
"user_triggered": bool(r[11]),
|
||||
"features_snapshot": feats,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"code": symbol[0],
|
||||
"name": symbol[1],
|
||||
"found": True,
|
||||
"base_date": str(base_date),
|
||||
"base_close": base_close,
|
||||
"steps": steps,
|
||||
}
|
||||
Reference in New Issue
Block a user