diff --git a/backend/app/api/chart.py b/backend/app/api/chart.py new file mode 100644 index 0000000..eb66568 --- /dev/null +++ b/backend/app/api/chart.py @@ -0,0 +1,116 @@ +"""차트 데이터 API: OHLCV + 보조 데이터 (감성, 거시). + +UI: /code 페이지 첫 로드 시 호출 → lightweight-charts 캔들 데이터로 사용. +""" +from __future__ import annotations + +from datetime import date, timedelta + +from fastapi import APIRouter, HTTPException, Query +from sqlalchemy import text + +from app.db.connection import get_engine + +router = APIRouter(prefix="/api/chart", tags=["chart"]) + + +@router.get("/{code}") +def get_chart( + code: str, + days: int = Query(default=180, ge=10, le=3650), + include_sentiment: bool = Query(default=True), + include_trading_value: bool = Query(default=True), +) -> dict: + eng = get_engine() + end = date.today() + start = end - timedelta(days=days) + with eng.connect() as conn: + symbol = conn.execute( + text("SELECT code, name, market FROM symbols WHERE code = :c"), + {"c": code}, + ).first() + if not symbol: + raise HTTPException(status_code=404, detail=f"unknown code: {code}") + + ohlcv_rows = conn.execute( + text( + """ + SELECT date, open, high, low, close, volume + FROM ohlcv_daily + WHERE code = :c AND date BETWEEN :s AND :e + ORDER BY date + """ + ), + {"c": code, "s": start, "e": end}, + ).all() + ohlcv = [ + { + "date": str(r[0]), + "open": float(r[1]) if r[1] is not None else None, + "high": float(r[2]) if r[2] is not None else None, + "low": float(r[3]) if r[3] is not None else None, + "close": float(r[4]) if r[4] is not None else None, + "volume": int(r[5]) if r[5] is not None else None, + } + for r in ohlcv_rows + ] + + sentiment: list[dict] = [] + if include_sentiment: + try: + s_rows = conn.execute( + text( + """ + SELECT date, n_articles, mean_score, weighted_score + FROM v_sentiment_daily + WHERE code = :c AND date BETWEEN :s AND :e + ORDER BY date + """ + ), + {"c": code, "s": start, "e": end}, + ).all() + sentiment = [ + { + "date": str(r[0]), + "n_articles": int(r[1]) if r[1] is not None else 0, + "mean_score": float(r[2]) if r[2] is not None else None, + "weighted_score": float(r[3]) if r[3] is not None else None, + } + for r in s_rows + ] + except Exception: # noqa: BLE001 + # v_sentiment_daily 뷰 아직 없을 수 있음 (마이그레이션 미실행) + sentiment = [] + + trading: list[dict] = [] + if include_trading_value: + tv_rows = conn.execute( + text( + """ + SELECT date, foreign_net, institution_net, individual_net + FROM trading_value_daily + WHERE code = :c AND date BETWEEN :s AND :e + ORDER BY date + """ + ), + {"c": code, "s": start, "e": end}, + ).all() + trading = [ + { + "date": str(r[0]), + "foreign_net": float(r[1]) if r[1] is not None else None, + "institution_net": float(r[2]) if r[2] is not None else None, + "individual_net": float(r[3]) if r[3] is not None else None, + } + for r in tv_rows + ] + + return { + "code": symbol[0], + "name": symbol[1], + "market": symbol[2], + "range": {"from": str(start), "to": str(end)}, + "ohlcv": ohlcv, + "sentiment": sentiment, + "trading_value": trading, + } diff --git a/backend/app/api/metrics.py b/backend/app/api/metrics.py new file mode 100644 index 0000000..f319582 --- /dev/null +++ b/backend/app/api/metrics.py @@ -0,0 +1,101 @@ +"""모델 성능 메트릭 API.""" +from __future__ import annotations + +from datetime import date, timedelta + +from fastapi import APIRouter, HTTPException, Query +from sqlalchemy import text + +from app.db.connection import get_engine + +router = APIRouter(prefix="/api/metrics", tags=["metrics"]) + + +@router.get("/{code}") +def code_metrics( + code: str, + window_days: int = Query(default=30, ge=1, le=365), +) -> dict: + """code 의 최근 window_days 윈도우 prediction_outcomes 집계.""" + eng = get_engine() + end = date.today() + start = end - timedelta(days=window_days) + with eng.connect() as conn: + sym = conn.execute( + text("SELECT code, name FROM symbols WHERE code = :c"), + {"c": code}, + ).first() + if not sym: + raise HTTPException(status_code=404, detail=f"unknown code: {code}") + rows = conn.execute( + text( + """ + SELECT model, horizon, + COUNT(*) AS n, + AVG(CASE WHEN direction_hit THEN 1.0 ELSE 0.0 END) AS hit_rate, + AVG(abs_error) AS mae + FROM prediction_outcomes + WHERE code = :c AND resolved_at >= :s + GROUP BY model, horizon + ORDER BY model, horizon + """ + ), + {"c": code, "s": start}, + ).all() + + return { + "code": sym[0], + "name": sym[1], + "window_days": window_days, + "range": {"from": str(start), "to": str(end)}, + "by_model_horizon": [ + { + "model": r[0], + "horizon": int(r[1]), + "n": int(r[2]), + "hit_rate": float(r[3]) if r[3] is not None else None, + "mae": float(r[4]) if r[4] is not None else None, + } + for r in rows + ], + } + + +@router.get("") +def overall_metrics( + window_days: int = Query(default=30, ge=1, le=365), +) -> dict: + """전체 시드 종목 누적 메트릭.""" + eng = get_engine() + end = date.today() + start = end - timedelta(days=window_days) + with eng.connect() as conn: + rows = conn.execute( + text( + """ + SELECT po.model, po.horizon, + COUNT(*) AS n, + AVG(CASE WHEN po.direction_hit THEN 1.0 ELSE 0.0 END) AS hit_rate, + AVG(po.abs_error) AS mae + FROM prediction_outcomes po + WHERE po.resolved_at >= :s + GROUP BY po.model, po.horizon + ORDER BY po.model, po.horizon + """ + ), + {"s": start}, + ).all() + return { + "window_days": window_days, + "range": {"from": str(start), "to": str(end)}, + "by_model_horizon": [ + { + "model": r[0], + "horizon": int(r[1]), + "n": int(r[2]), + "hit_rate": float(r[3]) if r[3] is not None else None, + "mae": float(r[4]) if r[4] is not None else None, + } + for r in rows + ], + } diff --git a/backend/app/api/news.py b/backend/app/api/news.py new file mode 100644 index 0000000..a100c9e --- /dev/null +++ b/backend/app/api/news.py @@ -0,0 +1,58 @@ +"""뉴스/공시 목록 API.""" +from __future__ import annotations + +from fastapi import APIRouter, HTTPException, Query +from sqlalchemy import text + +from app.db.connection import get_engine + +router = APIRouter(prefix="/api/news", tags=["news"]) + + +@router.get("/{code}") +def list_news( + code: str, + limit: int = Query(default=20, ge=1, le=200), + source: str | None = Query(default=None, description="naver_finance / google_rss / dart"), +) -> dict: + eng = get_engine() + with eng.connect() as conn: + sym = conn.execute( + text("SELECT code, name FROM symbols WHERE code = :c"), + {"c": code}, + ).first() + if not sym: + raise HTTPException(status_code=404, detail=f"unknown code: {code}") + where = "code = :c" + params: dict = {"c": code, "lim": limit} + if source: + where += " AND source = :src" + params["src"] = source + rows = conn.execute( + text( + f""" + SELECT source, published_at, title, url, sentiment_score, sentiment_label + FROM news + WHERE {where} + ORDER BY published_at DESC + LIMIT :lim + """ + ), + params, + ).all() + return { + "code": sym[0], + "name": sym[1], + "count": len(rows), + "items": [ + { + "source": r[0], + "published_at": r[1].isoformat() if r[1] else None, + "title": r[2], + "url": r[3], + "sentiment_score": float(r[4]) if r[4] is not None else None, + "sentiment_label": r[5], + } + for r in rows + ], + } diff --git a/backend/app/api/predict.py b/backend/app/api/predict.py new file mode 100644 index 0000000..80d4d35 --- /dev/null +++ b/backend/app/api/predict.py @@ -0,0 +1,137 @@ +"""예측 API. + +POST /api/predict/{code} : on-demand 예측 + predictions 적재 (user_triggered=TRUE) +GET /api/predict/{code}/latest : 가장 최근 base_date 의 예측 horizons 묶음 반환 + (UI 가 새로고침해도 같은 결과 보여줄 때 사용) +""" +from __future__ import annotations + +import json +import logging +from datetime import datetime, timedelta, timezone + +from fastapi import APIRouter, HTTPException, Query +from sqlalchemy import text + +from app.db.connection import get_engine +from app.pipelines.predict_one import predict_and_store + +logger = logging.getLogger(__name__) +KST = timezone(timedelta(hours=9)) + +router = APIRouter(prefix="/api/predict", tags=["predict"]) + + +@router.post("/{code}") +def predict_endpoint( + code: str, + horizons: str = Query(default="1,3,5", description="쉼표 구분 정수. ex) '1,3,5'"), + user_triggered: bool = Query(default=True), +) -> dict: + """on-demand 예측. UI 의 '예상차트 보기' 클릭이 호출.""" + eng = get_engine() + with eng.connect() as conn: + row = conn.execute( + text("SELECT code, name FROM symbols WHERE code = :c"), + {"c": code}, + ).first() + if not row: + raise HTTPException(status_code=404, detail=f"unknown code: {code}") + + try: + hs = tuple(int(x) for x in horizons.split(",") if x.strip()) + if not hs or any(h < 1 or h > 30 for h in hs): + raise ValueError("invalid horizons") + except ValueError as exc: + raise HTTPException(status_code=400, detail=f"bad horizons: {exc}") + + try: + result = predict_and_store(code, horizons=hs, user_triggered=user_triggered) + except RuntimeError as exc: + raise HTTPException(status_code=409, detail=str(exc)) + except Exception as exc: # noqa: BLE001 + logger.exception("predict_and_store failed for %s", code) + raise HTTPException(status_code=500, detail=str(exc)) + + return result + + +@router.get("/{code}/latest") +def latest_prediction(code: str) -> dict: + """가장 최근 base_date 에 대한 'ensemble' 예측 묶음.""" + eng = get_engine() + with eng.connect() as conn: + symbol = conn.execute( + text("SELECT code, name FROM symbols WHERE code = :c"), + {"c": code}, + ).first() + if not symbol: + raise HTTPException(status_code=404, detail=f"unknown code: {code}") + + latest_base = conn.execute( + text( + "SELECT MAX(base_date) FROM predictions " + "WHERE code = :c AND model = 'ensemble'" + ), + {"c": code}, + ).first() + if not latest_base or latest_base[0] is None: + return {"code": code, "found": False, "steps": []} + base_date = latest_base[0] + + rows = conn.execute( + text( + """ + SELECT predicted_at, target_date, horizon, direction, + prob_up, prob_flat, prob_down, expected_return, + point_forecast, ci_low, ci_high, user_triggered, + features_snapshot + FROM predictions + WHERE code = :c AND base_date = :bd AND model = 'ensemble' + ORDER BY horizon + """ + ), + {"c": code, "bd": base_date}, + ).all() + + # base_close (그 날 ohlcv) 도 같이 반환 — UI 가 차트 마지막 점에 이어붙일 때 사용 + base_close_row = conn.execute( + text("SELECT close FROM ohlcv_daily WHERE code = :c AND date = :d"), + {"c": code, "d": base_date}, + ).first() + base_close = float(base_close_row[0]) if base_close_row and base_close_row[0] is not None else None + + steps = [] + for r in rows: + feats = r[12] + if isinstance(feats, str): + try: + feats = json.loads(feats) + except Exception: # noqa: BLE001 + pass + steps.append( + { + "predicted_at": r[0].isoformat() if r[0] else None, + "target_date": str(r[1]), + "horizon": int(r[2]), + "direction": r[3], + "prob_up": float(r[4]) if r[4] is not None else None, + "prob_flat": float(r[5]) if r[5] is not None else None, + "prob_down": float(r[6]) if r[6] is not None else None, + "expected_return": float(r[7]) if r[7] is not None else None, + "point_close": float(r[8]) if r[8] is not None else None, + "ci_low": float(r[9]) if r[9] is not None else None, + "ci_high": float(r[10]) if r[10] is not None else None, + "user_triggered": bool(r[11]), + "features_snapshot": feats, + } + ) + + return { + "code": symbol[0], + "name": symbol[1], + "found": True, + "base_date": str(base_date), + "base_close": base_close, + "steps": steps, + } diff --git a/backend/app/api/symbols.py b/backend/app/api/symbols.py new file mode 100644 index 0000000..15417e1 --- /dev/null +++ b/backend/app/api/symbols.py @@ -0,0 +1,101 @@ +"""종목 검색 / 메타 API.""" +from __future__ import annotations + +from fastapi import APIRouter, HTTPException, Query +from sqlalchemy import text + +from app.db.connection import get_engine + +router = APIRouter(prefix="/api/symbols", tags=["symbols"]) + + +@router.get("/search") +def search_symbols( + q: str = Query(..., min_length=1, max_length=40, description="종목명 또는 코드 prefix/부분 일치"), + limit: int = Query(default=20, ge=1, le=100), + seed_only: bool = Query(default=False, description="true 면 학습/배치 대상 10종목만"), +) -> dict: + """이름은 trigram + ILIKE, 코드는 prefix 매치. + + 우선순위: + 1) 코드가 정확히 같으면 가장 위 + 2) 이름 prefix 매치 + 3) 이름 부분 매치 (trigram similarity) + """ + q_norm = q.strip() + if not q_norm: + raise HTTPException(status_code=400, detail="empty query") + + eng = get_engine() + where_seed = "AND is_seed = TRUE" if seed_only else "" + sql = text( + f""" + WITH ranked AS ( + SELECT code, name, market, sector, is_seed, + CASE + WHEN code = :q THEN 0 + WHEN code LIKE :prefix THEN 1 + WHEN name LIKE :prefix THEN 2 + WHEN name ILIKE :contains THEN 3 + ELSE 4 + END AS rank, + similarity(name, :q) AS sim + FROM symbols + WHERE active = TRUE + {where_seed} + AND (code LIKE :prefix OR name ILIKE :contains OR similarity(name, :q) > 0.2) + ) + SELECT code, name, market, sector, is_seed + FROM ranked + ORDER BY rank ASC, sim DESC, name ASC + LIMIT :lim + """ + ) + with eng.connect() as conn: + rows = conn.execute( + sql, + { + "q": q_norm, + "prefix": f"{q_norm}%", + "contains": f"%{q_norm}%", + "lim": limit, + }, + ).all() + return { + "q": q_norm, + "count": len(rows), + "items": [ + { + "code": r[0], + "name": r[1], + "market": r[2], + "sector": r[3], + "is_seed": bool(r[4]), + } + for r in rows + ], + } + + +@router.get("/{code}") +def get_symbol(code: str) -> dict: + eng = get_engine() + with eng.connect() as conn: + row = conn.execute( + text( + "SELECT code, name, market, sector, is_seed, active, created_at " + "FROM symbols WHERE code = :c" + ), + {"c": code}, + ).first() + if not row: + raise HTTPException(status_code=404, detail=f"unknown code: {code}") + return { + "code": row[0], + "name": row[1], + "market": row[2], + "sector": row[3], + "is_seed": bool(row[4]), + "active": bool(row[5]), + "created_at": str(row[6]) if row[6] else None, + } diff --git a/backend/app/main.py b/backend/app/main.py index b8a840c..8905471 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -6,7 +6,12 @@ from contextlib import asynccontextmanager from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware +from app.api.chart import router as chart_router +from app.api.metrics import router as metrics_router +from app.api.news import router as news_router +from app.api.predict import router as predict_router from app.api.refresh import router as refresh_router +from app.api.symbols import router as symbols_router from app.config import settings from app.db.connection import ping as db_ping from app.fetch import dart as dart_mod @@ -41,6 +46,11 @@ app.add_middleware( ) app.include_router(refresh_router) +app.include_router(symbols_router) +app.include_router(chart_router) +app.include_router(predict_router) +app.include_router(metrics_router) +app.include_router(news_router) def _resolved_device() -> str: