- GET /api/symbols/search?q=...&seed_only= : trigram + prefix + ILIKE 점수 합산 정렬
- GET /api/symbols/{code} : 메타
- GET /api/chart/{code}?days=N&include_* : OHLCV + 일별 감성 점수 + 외국인·기관 거래대금
- POST /api/predict/{code}?horizons=1,3,5 : on-demand 앙상블 예측 + DB 적재 (user_triggered=TRUE)
- GET /api/predict/{code}/latest : 최신 base_date 의 예측 묶음 + base_close (UI 가 차트 마지막 점에 이어 붙임)
- GET /api/metrics/{code}?window_days=N : 종목 단위 hit_rate / mae (model, horizon 별)
- GET /api/metrics?window_days=N : 전체 누적
- GET /api/news/{code}?source=&limit= : 최신순 뉴스/공시 목록 (감성 점수 포함)
main.py 에 6개 라우터 모두 include.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
138 lines
4.9 KiB
Python
"""예측 API.
|
|
|
|
POST /api/predict/{code} : on-demand 예측 + predictions 적재 (user_triggered=TRUE)
|
|
GET /api/predict/{code}/latest : 가장 최근 base_date 의 예측 horizons 묶음 반환
|
|
(UI 가 새로고침해도 같은 결과 보여줄 때 사용)
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
from datetime import datetime, timedelta, timezone
|
|
|
|
from fastapi import APIRouter, HTTPException, Query
|
|
from sqlalchemy import text
|
|
|
|
from app.db.connection import get_engine
|
|
from app.pipelines.predict_one import predict_and_store
|
|
|
|
# Module logger; predict_endpoint uses it to record unexpected prediction failures with a traceback.
logger = logging.getLogger(__name__)
# UTC+9 offset (presumably Korea Standard Time, per the name).
# NOTE(review): not referenced anywhere in this module — confirm no other module imports it before removing.
KST = timezone(timedelta(hours=9))

# Every endpoint below is mounted under /api/predict and grouped under the "predict" tag in the OpenAPI docs.
router = APIRouter(prefix="/api/predict", tags=["predict"])
|
|
|
|
|
|
# Largest forecast horizon (in steps) a caller may request.
_MAX_HORIZON = 30


def _parse_horizons(horizons: str) -> tuple[int, ...]:
    """Parse a comma-separated horizon list into a de-duplicated tuple of ints.

    Empty items (e.g. ``"1,,3"`` or a trailing comma) are skipped. Duplicates
    are dropped while keeping first-seen order, so the same horizon is never
    predicted and stored twice within one request.

    Raises:
        HTTPException: 400 when an item is not an integer, no horizon remains,
            or any horizon lies outside ``1.._MAX_HORIZON``.
    """
    try:
        parsed = [int(item) for item in horizons.split(",") if item.strip()]
        if not parsed or any(h < 1 or h > _MAX_HORIZON for h in parsed):
            raise ValueError("invalid horizons")
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=f"bad horizons: {exc}") from exc
    # dict.fromkeys preserves insertion order -> order-preserving de-duplication.
    return tuple(dict.fromkeys(parsed))


@router.post("/{code}")
def predict_endpoint(
    code: str,
    horizons: str = Query(default="1,3,5", description="쉼표 구분 정수. ex) '1,3,5'"),
    user_triggered: bool = Query(default=True),
) -> dict:
    """On-demand prediction, invoked when the UI's '예상차트 보기' (show forecast chart) button is clicked.

    Args:
        code: Symbol code; must exist in the ``symbols`` table.
        horizons: Comma-separated horizon integers, each in ``1.._MAX_HORIZON``.
        user_triggered: Forwarded to ``predict_and_store``.

    Returns:
        Whatever ``predict_and_store`` returns.

    Raises:
        HTTPException: 404 for an unknown code, 400 for malformed horizons,
            409 when ``predict_and_store`` raises ``RuntimeError``, and 500
            for any other failure (which is also logged with its traceback).
    """
    eng = get_engine()
    with eng.connect() as conn:
        row = conn.execute(
            text("SELECT code, name FROM symbols WHERE code = :c"),
            {"c": code},
        ).first()
    if not row:
        raise HTTPException(status_code=404, detail=f"unknown code: {code}")

    hs = _parse_horizons(horizons)

    try:
        result = predict_and_store(code, horizons=hs, user_triggered=user_triggered)
    except RuntimeError as exc:
        raise HTTPException(status_code=409, detail=str(exc)) from exc
    except Exception as exc:  # noqa: BLE001 - API boundary: log, then convert to a 500.
        logger.exception("predict_and_store failed for %s", code)
        raise HTTPException(status_code=500, detail=str(exc)) from exc

    return result
|
|
|
|
|
|
def _opt_float(value) -> float | None:
    """Return ``float(value)``, or ``None`` when *value* is ``None``."""
    return float(value) if value is not None else None


def _parse_features(feats):
    """Decode a JSON-encoded ``features_snapshot`` string; pass other values through.

    NOTE(review): the driver may already return a dict (e.g. for a JSON/JSONB
    column) — confirm the column type. A string that is not valid JSON is
    returned unchanged (best effort) instead of failing the whole request.
    """
    if isinstance(feats, str):
        try:
            return json.loads(feats)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass.
            return feats
    return feats


def _is_newer(candidate, current) -> bool:
    """Return True when *candidate* predicted_at should replace *current*; ``None`` counts as oldest."""
    if candidate is None:
        return False
    return current is None or candidate > current


def _row_to_step(r) -> dict:
    """Convert one row of the ``predictions`` SELECT in latest_prediction into a JSON-safe step dict."""
    return {
        "predicted_at": r.predicted_at.isoformat() if r.predicted_at else None,
        # Guard NULL explicitly: str(None) would leak the literal string "None".
        "target_date": str(r.target_date) if r.target_date is not None else None,
        "horizon": int(r.horizon),
        "direction": r.direction,
        "prob_up": _opt_float(r.prob_up),
        "prob_flat": _opt_float(r.prob_flat),
        "prob_down": _opt_float(r.prob_down),
        "expected_return": _opt_float(r.expected_return),
        # The DB column is point_forecast; the API key stays "point_close" for the UI.
        "point_close": _opt_float(r.point_forecast),
        "ci_low": _opt_float(r.ci_low),
        "ci_high": _opt_float(r.ci_high),
        "user_triggered": bool(r.user_triggered),
        "features_snapshot": _parse_features(r.features_snapshot),
    }


@router.get("/{code}/latest")
def latest_prediction(code: str) -> dict:
    """Return the 'ensemble' prediction bundle for the most recent base_date of *code*.

    Returns:
        ``{"code", "found": False, "steps": []}`` when no ensemble prediction
        exists. Otherwise it returns ``{"code", "name", "found": True,
        "base_date", "base_close", "steps"}``, with ``steps`` ordered by
        horizon and holding one entry per horizon.

    Raises:
        HTTPException: 404 when *code* is not in the ``symbols`` table.
    """
    eng = get_engine()
    with eng.connect() as conn:
        symbol = conn.execute(
            text("SELECT code, name FROM symbols WHERE code = :c"),
            {"c": code},
        ).first()
        if not symbol:
            raise HTTPException(status_code=404, detail=f"unknown code: {code}")

        latest_base = conn.execute(
            text(
                "SELECT MAX(base_date) FROM predictions "
                "WHERE code = :c AND model = 'ensemble'"
            ),
            {"c": code},
        ).first()
        if not latest_base or latest_base[0] is None:
            return {"code": code, "found": False, "steps": []}
        base_date = latest_base[0]

        rows = conn.execute(
            text(
                """
                SELECT predicted_at, target_date, horizon, direction,
                       prob_up, prob_flat, prob_down, expected_return,
                       point_forecast, ci_low, ci_high, user_triggered,
                       features_snapshot
                FROM predictions
                WHERE code = :c AND base_date = :bd AND model = 'ensemble'
                ORDER BY horizon
                """
            ),
            {"c": code, "bd": base_date},
        ).all()

        # The base_date close is returned too, so the UI can attach the forecast to the chart's last point.
        base_close_row = conn.execute(
            text("SELECT close FROM ohlcv_daily WHERE code = :c AND date = :d"),
            {"c": code, "d": base_date},
        ).first()

    base_close = _opt_float(base_close_row[0]) if base_close_row else None

    # Several runs can share one base_date (e.g. repeated on-demand POSTs, or
    # user-triggered rows alongside others). Keep only the newest row per
    # horizon so steps never contains duplicate horizons.
    latest_by_horizon: dict = {}
    for r in rows:
        h = int(r.horizon)
        current = latest_by_horizon.get(h)
        if current is None or _is_newer(r.predicted_at, current.predicted_at):
            latest_by_horizon[h] = r

    steps = [_row_to_step(latest_by_horizon[h]) for h in sorted(latest_by_horizon)]

    return {
        "code": symbol[0],
        "name": symbol[1],
        "found": True,
        "base_date": str(base_date),
        "base_close": base_close,
        "steps": steps,
    }
|