Compare commits
6 Commits
239b104a2b
...
bc016ab76d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bc016ab76d | ||
|
|
4fb6cec383 | ||
|
|
41ee9d5bb0 | ||
|
|
bf4fb01146 | ||
|
|
b1ca6ab5d3 | ||
|
|
edda01adbf |
32
README.md
32
README.md
@@ -129,15 +129,29 @@ stock_chart_site/
|
|||||||
|
|
||||||
## 진행 계획
|
## 진행 계획
|
||||||
|
|
||||||
- Phase 0 — 스캐폴드 (현재): Docker 환경 + DB 스키마 + 빈 FastAPI/Next.js + build.bat
|
- [x] Phase 0 — 스캐폴드: Docker 환경 + DB 스키마 + FastAPI/Next.js + build.bat
|
||||||
- Phase 1a — pykrx 데이터 파이프: 일봉/외인기관/지수 + DART + 뉴스 RSS + 거시
|
- [x] Phase 1a — pykrx 데이터 파이프: 일봉/외인기관/지수 + DART + 뉴스 RSS + 거시
|
||||||
- Phase 1b — KIS EOD (키 받으면)
|
- [x] Phase 1b — KIS read-only EOD (스모크 통과)
|
||||||
- Phase 2 — KR-FinBERT 감성 점수 + 일별 집계
|
- [x] Phase 2 — KR-FinBERT 감성 점수 + 일별 집계 뷰
|
||||||
- Phase 3 — Chronos zero-shot 예측 적재
|
- [x] Phase 3 — Chronos zero-shot 예측 어댑터 + 피처 빌더
|
||||||
- Phase 4 — LightGBM walk-forward + `prediction_outcomes` 누적 시작
|
- [x] Phase 4 — LightGBM walk-forward + ensemble + 매칭/재학습 잡
|
||||||
- Phase 5 — FastAPI 엔드포인트 (검색, 차트, on-demand 예측, 메트릭)
|
- [x] Phase 5 — FastAPI 엔드포인트 (검색/차트/예측/메트릭/뉴스)
|
||||||
- Phase 6 — Next.js UI (검색 + 현재 차트 + 예상차트 토글)
|
- [x] Phase 6 — Next.js UI (검색 + 현재 차트 + 예상차트 overlay)
|
||||||
- Phase 7 (옵션) — 백테스트 페이지 + 주간 자동 재학습
|
- [ ] Phase 7 (옵션) — 백테스트 페이지 + Chronos/LGBM 단독 shadow 예측
|
||||||
|
|
||||||
|
### API 엔드포인트 (요약)
|
||||||
|
|
||||||
|
| 메서드 | 경로 | 설명 |
|
||||||
|
|---|---|---|
|
||||||
|
| GET | `/health`, `/health/db`, `/health/keys` | 헬스/외부키 ping |
|
||||||
|
| POST | `/api/refresh/{code}?lookback_days=N` | 수동 갱신 |
|
||||||
|
| GET | `/api/symbols/search?q=&seed_only=` | 종목 검색 (trigram + prefix) |
|
||||||
|
| GET | `/api/symbols/{code}` | 종목 메타 |
|
||||||
|
| GET | `/api/chart/{code}?days=N` | OHLCV + 감성 + 외인기관거래대금 |
|
||||||
|
| POST | `/api/predict/{code}?horizons=1,3,5` | on-demand 앙상블 예측 (user_triggered) |
|
||||||
|
| GET | `/api/predict/{code}/latest` | 최신 base_date 예측 묶음 (UI overlay) |
|
||||||
|
| GET | `/api/metrics/{code}?window_days=N` | 종목 hit_rate / mae |
|
||||||
|
| GET | `/api/news/{code}?limit=N&source=` | 최근 뉴스/공시 + 감성 |
|
||||||
|
|
||||||
## 동작 모델 메모
|
## 동작 모델 메모
|
||||||
|
|
||||||
|
|||||||
116
backend/app/api/chart.py
Normal file
116
backend/app/api/chart.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
"""차트 데이터 API: OHLCV + 보조 데이터 (감성, 거시).
|
||||||
|
|
||||||
|
UI: /code 페이지 첫 로드 시 호출 → lightweight-charts 캔들 데이터로 사용.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import date, timedelta
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Query
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from app.db.connection import get_engine
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/chart", tags=["chart"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{code}")
|
||||||
|
def get_chart(
|
||||||
|
code: str,
|
||||||
|
days: int = Query(default=180, ge=10, le=3650),
|
||||||
|
include_sentiment: bool = Query(default=True),
|
||||||
|
include_trading_value: bool = Query(default=True),
|
||||||
|
) -> dict:
|
||||||
|
eng = get_engine()
|
||||||
|
end = date.today()
|
||||||
|
start = end - timedelta(days=days)
|
||||||
|
with eng.connect() as conn:
|
||||||
|
symbol = conn.execute(
|
||||||
|
text("SELECT code, name, market FROM symbols WHERE code = :c"),
|
||||||
|
{"c": code},
|
||||||
|
).first()
|
||||||
|
if not symbol:
|
||||||
|
raise HTTPException(status_code=404, detail=f"unknown code: {code}")
|
||||||
|
|
||||||
|
ohlcv_rows = conn.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
SELECT date, open, high, low, close, volume
|
||||||
|
FROM ohlcv_daily
|
||||||
|
WHERE code = :c AND date BETWEEN :s AND :e
|
||||||
|
ORDER BY date
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{"c": code, "s": start, "e": end},
|
||||||
|
).all()
|
||||||
|
ohlcv = [
|
||||||
|
{
|
||||||
|
"date": str(r[0]),
|
||||||
|
"open": float(r[1]) if r[1] is not None else None,
|
||||||
|
"high": float(r[2]) if r[2] is not None else None,
|
||||||
|
"low": float(r[3]) if r[3] is not None else None,
|
||||||
|
"close": float(r[4]) if r[4] is not None else None,
|
||||||
|
"volume": int(r[5]) if r[5] is not None else None,
|
||||||
|
}
|
||||||
|
for r in ohlcv_rows
|
||||||
|
]
|
||||||
|
|
||||||
|
sentiment: list[dict] = []
|
||||||
|
if include_sentiment:
|
||||||
|
try:
|
||||||
|
s_rows = conn.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
SELECT date, n_articles, mean_score, weighted_score
|
||||||
|
FROM v_sentiment_daily
|
||||||
|
WHERE code = :c AND date BETWEEN :s AND :e
|
||||||
|
ORDER BY date
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{"c": code, "s": start, "e": end},
|
||||||
|
).all()
|
||||||
|
sentiment = [
|
||||||
|
{
|
||||||
|
"date": str(r[0]),
|
||||||
|
"n_articles": int(r[1]) if r[1] is not None else 0,
|
||||||
|
"mean_score": float(r[2]) if r[2] is not None else None,
|
||||||
|
"weighted_score": float(r[3]) if r[3] is not None else None,
|
||||||
|
}
|
||||||
|
for r in s_rows
|
||||||
|
]
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
# v_sentiment_daily 뷰 아직 없을 수 있음 (마이그레이션 미실행)
|
||||||
|
sentiment = []
|
||||||
|
|
||||||
|
trading: list[dict] = []
|
||||||
|
if include_trading_value:
|
||||||
|
tv_rows = conn.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
SELECT date, foreign_net, institution_net, individual_net
|
||||||
|
FROM trading_value_daily
|
||||||
|
WHERE code = :c AND date BETWEEN :s AND :e
|
||||||
|
ORDER BY date
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{"c": code, "s": start, "e": end},
|
||||||
|
).all()
|
||||||
|
trading = [
|
||||||
|
{
|
||||||
|
"date": str(r[0]),
|
||||||
|
"foreign_net": float(r[1]) if r[1] is not None else None,
|
||||||
|
"institution_net": float(r[2]) if r[2] is not None else None,
|
||||||
|
"individual_net": float(r[3]) if r[3] is not None else None,
|
||||||
|
}
|
||||||
|
for r in tv_rows
|
||||||
|
]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"code": symbol[0],
|
||||||
|
"name": symbol[1],
|
||||||
|
"market": symbol[2],
|
||||||
|
"range": {"from": str(start), "to": str(end)},
|
||||||
|
"ohlcv": ohlcv,
|
||||||
|
"sentiment": sentiment,
|
||||||
|
"trading_value": trading,
|
||||||
|
}
|
||||||
101
backend/app/api/metrics.py
Normal file
101
backend/app/api/metrics.py
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
"""모델 성능 메트릭 API."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import date, timedelta
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Query
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from app.db.connection import get_engine
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/metrics", tags=["metrics"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{code}")
|
||||||
|
def code_metrics(
|
||||||
|
code: str,
|
||||||
|
window_days: int = Query(default=30, ge=1, le=365),
|
||||||
|
) -> dict:
|
||||||
|
"""code 의 최근 window_days 윈도우 prediction_outcomes 집계."""
|
||||||
|
eng = get_engine()
|
||||||
|
end = date.today()
|
||||||
|
start = end - timedelta(days=window_days)
|
||||||
|
with eng.connect() as conn:
|
||||||
|
sym = conn.execute(
|
||||||
|
text("SELECT code, name FROM symbols WHERE code = :c"),
|
||||||
|
{"c": code},
|
||||||
|
).first()
|
||||||
|
if not sym:
|
||||||
|
raise HTTPException(status_code=404, detail=f"unknown code: {code}")
|
||||||
|
rows = conn.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
SELECT model, horizon,
|
||||||
|
COUNT(*) AS n,
|
||||||
|
AVG(CASE WHEN direction_hit THEN 1.0 ELSE 0.0 END) AS hit_rate,
|
||||||
|
AVG(abs_error) AS mae
|
||||||
|
FROM prediction_outcomes
|
||||||
|
WHERE code = :c AND resolved_at >= :s
|
||||||
|
GROUP BY model, horizon
|
||||||
|
ORDER BY model, horizon
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{"c": code, "s": start},
|
||||||
|
).all()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"code": sym[0],
|
||||||
|
"name": sym[1],
|
||||||
|
"window_days": window_days,
|
||||||
|
"range": {"from": str(start), "to": str(end)},
|
||||||
|
"by_model_horizon": [
|
||||||
|
{
|
||||||
|
"model": r[0],
|
||||||
|
"horizon": int(r[1]),
|
||||||
|
"n": int(r[2]),
|
||||||
|
"hit_rate": float(r[3]) if r[3] is not None else None,
|
||||||
|
"mae": float(r[4]) if r[4] is not None else None,
|
||||||
|
}
|
||||||
|
for r in rows
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("")
|
||||||
|
def overall_metrics(
|
||||||
|
window_days: int = Query(default=30, ge=1, le=365),
|
||||||
|
) -> dict:
|
||||||
|
"""전체 시드 종목 누적 메트릭."""
|
||||||
|
eng = get_engine()
|
||||||
|
end = date.today()
|
||||||
|
start = end - timedelta(days=window_days)
|
||||||
|
with eng.connect() as conn:
|
||||||
|
rows = conn.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
SELECT po.model, po.horizon,
|
||||||
|
COUNT(*) AS n,
|
||||||
|
AVG(CASE WHEN po.direction_hit THEN 1.0 ELSE 0.0 END) AS hit_rate,
|
||||||
|
AVG(po.abs_error) AS mae
|
||||||
|
FROM prediction_outcomes po
|
||||||
|
WHERE po.resolved_at >= :s
|
||||||
|
GROUP BY po.model, po.horizon
|
||||||
|
ORDER BY po.model, po.horizon
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{"s": start},
|
||||||
|
).all()
|
||||||
|
return {
|
||||||
|
"window_days": window_days,
|
||||||
|
"range": {"from": str(start), "to": str(end)},
|
||||||
|
"by_model_horizon": [
|
||||||
|
{
|
||||||
|
"model": r[0],
|
||||||
|
"horizon": int(r[1]),
|
||||||
|
"n": int(r[2]),
|
||||||
|
"hit_rate": float(r[3]) if r[3] is not None else None,
|
||||||
|
"mae": float(r[4]) if r[4] is not None else None,
|
||||||
|
}
|
||||||
|
for r in rows
|
||||||
|
],
|
||||||
|
}
|
||||||
58
backend/app/api/news.py
Normal file
58
backend/app/api/news.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
"""뉴스/공시 목록 API."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Query
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from app.db.connection import get_engine
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/news", tags=["news"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{code}")
|
||||||
|
def list_news(
|
||||||
|
code: str,
|
||||||
|
limit: int = Query(default=20, ge=1, le=200),
|
||||||
|
source: str | None = Query(default=None, description="naver_finance / google_rss / dart"),
|
||||||
|
) -> dict:
|
||||||
|
eng = get_engine()
|
||||||
|
with eng.connect() as conn:
|
||||||
|
sym = conn.execute(
|
||||||
|
text("SELECT code, name FROM symbols WHERE code = :c"),
|
||||||
|
{"c": code},
|
||||||
|
).first()
|
||||||
|
if not sym:
|
||||||
|
raise HTTPException(status_code=404, detail=f"unknown code: {code}")
|
||||||
|
where = "code = :c"
|
||||||
|
params: dict = {"c": code, "lim": limit}
|
||||||
|
if source:
|
||||||
|
where += " AND source = :src"
|
||||||
|
params["src"] = source
|
||||||
|
rows = conn.execute(
|
||||||
|
text(
|
||||||
|
f"""
|
||||||
|
SELECT source, published_at, title, url, sentiment_score, sentiment_label
|
||||||
|
FROM news
|
||||||
|
WHERE {where}
|
||||||
|
ORDER BY published_at DESC
|
||||||
|
LIMIT :lim
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
params,
|
||||||
|
).all()
|
||||||
|
return {
|
||||||
|
"code": sym[0],
|
||||||
|
"name": sym[1],
|
||||||
|
"count": len(rows),
|
||||||
|
"items": [
|
||||||
|
{
|
||||||
|
"source": r[0],
|
||||||
|
"published_at": r[1].isoformat() if r[1] else None,
|
||||||
|
"title": r[2],
|
||||||
|
"url": r[3],
|
||||||
|
"sentiment_score": float(r[4]) if r[4] is not None else None,
|
||||||
|
"sentiment_label": r[5],
|
||||||
|
}
|
||||||
|
for r in rows
|
||||||
|
],
|
||||||
|
}
|
||||||
137
backend/app/api/predict.py
Normal file
137
backend/app/api/predict.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
"""예측 API.
|
||||||
|
|
||||||
|
POST /api/predict/{code} : on-demand 예측 + predictions 적재 (user_triggered=TRUE)
|
||||||
|
GET /api/predict/{code}/latest : 가장 최근 base_date 의 예측 horizons 묶음 반환
|
||||||
|
(UI 가 새로고침해도 같은 결과 보여줄 때 사용)
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Query
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from app.db.connection import get_engine
|
||||||
|
from app.pipelines.predict_one import predict_and_store
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
KST = timezone(timedelta(hours=9))
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/predict", tags=["predict"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{code}")
|
||||||
|
def predict_endpoint(
|
||||||
|
code: str,
|
||||||
|
horizons: str = Query(default="1,3,5", description="쉼표 구분 정수. ex) '1,3,5'"),
|
||||||
|
user_triggered: bool = Query(default=True),
|
||||||
|
) -> dict:
|
||||||
|
"""on-demand 예측. UI 의 '예상차트 보기' 클릭이 호출."""
|
||||||
|
eng = get_engine()
|
||||||
|
with eng.connect() as conn:
|
||||||
|
row = conn.execute(
|
||||||
|
text("SELECT code, name FROM symbols WHERE code = :c"),
|
||||||
|
{"c": code},
|
||||||
|
).first()
|
||||||
|
if not row:
|
||||||
|
raise HTTPException(status_code=404, detail=f"unknown code: {code}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
hs = tuple(int(x) for x in horizons.split(",") if x.strip())
|
||||||
|
if not hs or any(h < 1 or h > 30 for h in hs):
|
||||||
|
raise ValueError("invalid horizons")
|
||||||
|
except ValueError as exc:
|
||||||
|
raise HTTPException(status_code=400, detail=f"bad horizons: {exc}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = predict_and_store(code, horizons=hs, user_triggered=user_triggered)
|
||||||
|
except RuntimeError as exc:
|
||||||
|
raise HTTPException(status_code=409, detail=str(exc))
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
logger.exception("predict_and_store failed for %s", code)
|
||||||
|
raise HTTPException(status_code=500, detail=str(exc))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{code}/latest")
|
||||||
|
def latest_prediction(code: str) -> dict:
|
||||||
|
"""가장 최근 base_date 에 대한 'ensemble' 예측 묶음."""
|
||||||
|
eng = get_engine()
|
||||||
|
with eng.connect() as conn:
|
||||||
|
symbol = conn.execute(
|
||||||
|
text("SELECT code, name FROM symbols WHERE code = :c"),
|
||||||
|
{"c": code},
|
||||||
|
).first()
|
||||||
|
if not symbol:
|
||||||
|
raise HTTPException(status_code=404, detail=f"unknown code: {code}")
|
||||||
|
|
||||||
|
latest_base = conn.execute(
|
||||||
|
text(
|
||||||
|
"SELECT MAX(base_date) FROM predictions "
|
||||||
|
"WHERE code = :c AND model = 'ensemble'"
|
||||||
|
),
|
||||||
|
{"c": code},
|
||||||
|
).first()
|
||||||
|
if not latest_base or latest_base[0] is None:
|
||||||
|
return {"code": code, "found": False, "steps": []}
|
||||||
|
base_date = latest_base[0]
|
||||||
|
|
||||||
|
rows = conn.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
SELECT predicted_at, target_date, horizon, direction,
|
||||||
|
prob_up, prob_flat, prob_down, expected_return,
|
||||||
|
point_forecast, ci_low, ci_high, user_triggered,
|
||||||
|
features_snapshot
|
||||||
|
FROM predictions
|
||||||
|
WHERE code = :c AND base_date = :bd AND model = 'ensemble'
|
||||||
|
ORDER BY horizon
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{"c": code, "bd": base_date},
|
||||||
|
).all()
|
||||||
|
|
||||||
|
# base_close (그 날 ohlcv) 도 같이 반환 — UI 가 차트 마지막 점에 이어붙일 때 사용
|
||||||
|
base_close_row = conn.execute(
|
||||||
|
text("SELECT close FROM ohlcv_daily WHERE code = :c AND date = :d"),
|
||||||
|
{"c": code, "d": base_date},
|
||||||
|
).first()
|
||||||
|
base_close = float(base_close_row[0]) if base_close_row and base_close_row[0] is not None else None
|
||||||
|
|
||||||
|
steps = []
|
||||||
|
for r in rows:
|
||||||
|
feats = r[12]
|
||||||
|
if isinstance(feats, str):
|
||||||
|
try:
|
||||||
|
feats = json.loads(feats)
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
pass
|
||||||
|
steps.append(
|
||||||
|
{
|
||||||
|
"predicted_at": r[0].isoformat() if r[0] else None,
|
||||||
|
"target_date": str(r[1]),
|
||||||
|
"horizon": int(r[2]),
|
||||||
|
"direction": r[3],
|
||||||
|
"prob_up": float(r[4]) if r[4] is not None else None,
|
||||||
|
"prob_flat": float(r[5]) if r[5] is not None else None,
|
||||||
|
"prob_down": float(r[6]) if r[6] is not None else None,
|
||||||
|
"expected_return": float(r[7]) if r[7] is not None else None,
|
||||||
|
"point_close": float(r[8]) if r[8] is not None else None,
|
||||||
|
"ci_low": float(r[9]) if r[9] is not None else None,
|
||||||
|
"ci_high": float(r[10]) if r[10] is not None else None,
|
||||||
|
"user_triggered": bool(r[11]),
|
||||||
|
"features_snapshot": feats,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"code": symbol[0],
|
||||||
|
"name": symbol[1],
|
||||||
|
"found": True,
|
||||||
|
"base_date": str(base_date),
|
||||||
|
"base_close": base_close,
|
||||||
|
"steps": steps,
|
||||||
|
}
|
||||||
101
backend/app/api/symbols.py
Normal file
101
backend/app/api/symbols.py
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
"""종목 검색 / 메타 API."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Query
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from app.db.connection import get_engine
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/symbols", tags=["symbols"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/search")
|
||||||
|
def search_symbols(
|
||||||
|
q: str = Query(..., min_length=1, max_length=40, description="종목명 또는 코드 prefix/부분 일치"),
|
||||||
|
limit: int = Query(default=20, ge=1, le=100),
|
||||||
|
seed_only: bool = Query(default=False, description="true 면 학습/배치 대상 10종목만"),
|
||||||
|
) -> dict:
|
||||||
|
"""이름은 trigram + ILIKE, 코드는 prefix 매치.
|
||||||
|
|
||||||
|
우선순위:
|
||||||
|
1) 코드가 정확히 같으면 가장 위
|
||||||
|
2) 이름 prefix 매치
|
||||||
|
3) 이름 부분 매치 (trigram similarity)
|
||||||
|
"""
|
||||||
|
q_norm = q.strip()
|
||||||
|
if not q_norm:
|
||||||
|
raise HTTPException(status_code=400, detail="empty query")
|
||||||
|
|
||||||
|
eng = get_engine()
|
||||||
|
where_seed = "AND is_seed = TRUE" if seed_only else ""
|
||||||
|
sql = text(
|
||||||
|
f"""
|
||||||
|
WITH ranked AS (
|
||||||
|
SELECT code, name, market, sector, is_seed,
|
||||||
|
CASE
|
||||||
|
WHEN code = :q THEN 0
|
||||||
|
WHEN code LIKE :prefix THEN 1
|
||||||
|
WHEN name LIKE :prefix THEN 2
|
||||||
|
WHEN name ILIKE :contains THEN 3
|
||||||
|
ELSE 4
|
||||||
|
END AS rank,
|
||||||
|
similarity(name, :q) AS sim
|
||||||
|
FROM symbols
|
||||||
|
WHERE active = TRUE
|
||||||
|
{where_seed}
|
||||||
|
AND (code LIKE :prefix OR name ILIKE :contains OR similarity(name, :q) > 0.2)
|
||||||
|
)
|
||||||
|
SELECT code, name, market, sector, is_seed
|
||||||
|
FROM ranked
|
||||||
|
ORDER BY rank ASC, sim DESC, name ASC
|
||||||
|
LIMIT :lim
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
with eng.connect() as conn:
|
||||||
|
rows = conn.execute(
|
||||||
|
sql,
|
||||||
|
{
|
||||||
|
"q": q_norm,
|
||||||
|
"prefix": f"{q_norm}%",
|
||||||
|
"contains": f"%{q_norm}%",
|
||||||
|
"lim": limit,
|
||||||
|
},
|
||||||
|
).all()
|
||||||
|
return {
|
||||||
|
"q": q_norm,
|
||||||
|
"count": len(rows),
|
||||||
|
"items": [
|
||||||
|
{
|
||||||
|
"code": r[0],
|
||||||
|
"name": r[1],
|
||||||
|
"market": r[2],
|
||||||
|
"sector": r[3],
|
||||||
|
"is_seed": bool(r[4]),
|
||||||
|
}
|
||||||
|
for r in rows
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{code}")
|
||||||
|
def get_symbol(code: str) -> dict:
|
||||||
|
eng = get_engine()
|
||||||
|
with eng.connect() as conn:
|
||||||
|
row = conn.execute(
|
||||||
|
text(
|
||||||
|
"SELECT code, name, market, sector, is_seed, active, created_at "
|
||||||
|
"FROM symbols WHERE code = :c"
|
||||||
|
),
|
||||||
|
{"c": code},
|
||||||
|
).first()
|
||||||
|
if not row:
|
||||||
|
raise HTTPException(status_code=404, detail=f"unknown code: {code}")
|
||||||
|
return {
|
||||||
|
"code": row[0],
|
||||||
|
"name": row[1],
|
||||||
|
"market": row[2],
|
||||||
|
"sector": row[3],
|
||||||
|
"is_seed": bool(row[4]),
|
||||||
|
"active": bool(row[5]),
|
||||||
|
"created_at": str(row[6]) if row[6] else None,
|
||||||
|
}
|
||||||
53
backend/app/db/migrate.py
Normal file
53
backend/app/db/migrate.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
"""Manual migration runner.
|
||||||
|
|
||||||
|
docker-entrypoint-initdb.d 는 fresh DB 첫 기동 때만 동작. 이미 동작 중인 DB 에
|
||||||
|
새 마이그레이션을 적용하려면 이 스크립트로:
|
||||||
|
|
||||||
|
python -m app.db.migrate
|
||||||
|
|
||||||
|
모든 SQL 파일은 idempotent (CREATE IF NOT EXISTS / CREATE OR REPLACE) 여야 함.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from app.db.connection import get_engine
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
MIGRATIONS_DIR = Path(__file__).parent / "migrations"
|
||||||
|
|
||||||
|
|
||||||
|
def apply_all() -> dict[str, str]:
|
||||||
|
"""migrations/ 안 .sql 들을 이름순으로 적용. 결과: {filename: 'ok'|'failed: ...'}."""
|
||||||
|
eng = get_engine()
|
||||||
|
results: dict[str, str] = {}
|
||||||
|
files = sorted(MIGRATIONS_DIR.glob("*.sql"))
|
||||||
|
if not files:
|
||||||
|
logger.warning("no migration files in %s", MIGRATIONS_DIR)
|
||||||
|
return results
|
||||||
|
for path in files:
|
||||||
|
sql = path.read_text(encoding="utf-8")
|
||||||
|
# psql meta-command 제거 (\set ON_ERROR_STOP 등)
|
||||||
|
cleaned = "\n".join(
|
||||||
|
ln for ln in sql.splitlines() if not ln.strip().startswith("\\")
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
with eng.begin() as conn:
|
||||||
|
conn.execute(text(cleaned))
|
||||||
|
results[path.name] = "ok"
|
||||||
|
logger.info("applied %s", path.name)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
results[path.name] = f"failed: {exc}"
|
||||||
|
logger.exception("migration %s failed", path.name)
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
||||||
|
out = apply_all()
|
||||||
|
for k, v in out.items():
|
||||||
|
print(f"{k}: {v}")
|
||||||
32
backend/app/db/migrations/002_sentiment_view.sql
Normal file
32
backend/app/db/migrations/002_sentiment_view.sql
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
-- Phase 2: 일별 종목별 감성 집계 뷰.
|
||||||
|
-- weighted_score : 소스별 가중치 적용
|
||||||
|
-- naver_finance 1.0 (가장 직접적인 종목 페이지 뉴스)
|
||||||
|
-- google_rss 0.7 (관련성 노이즈 있음)
|
||||||
|
-- dart 0.5 (공시는 short title 만으로는 감성이 약함)
|
||||||
|
|
||||||
|
\set ON_ERROR_STOP on
|
||||||
|
|
||||||
|
CREATE OR REPLACE VIEW v_sentiment_daily AS
|
||||||
|
SELECT
|
||||||
|
code,
|
||||||
|
(published_at AT TIME ZONE 'Asia/Seoul')::date AS date,
|
||||||
|
COUNT(*) AS n_articles,
|
||||||
|
AVG(sentiment_score)::REAL AS mean_score,
|
||||||
|
AVG(CASE WHEN sentiment_label = 'positive' THEN 1.0 ELSE 0.0 END)::REAL AS pos_ratio,
|
||||||
|
AVG(CASE WHEN sentiment_label = 'negative' THEN 1.0 ELSE 0.0 END)::REAL AS neg_ratio,
|
||||||
|
AVG(CASE WHEN sentiment_label = 'neutral' THEN 1.0 ELSE 0.0 END)::REAL AS neu_ratio,
|
||||||
|
AVG(
|
||||||
|
sentiment_score * CASE source
|
||||||
|
WHEN 'naver_finance' THEN 1.0
|
||||||
|
WHEN 'google_rss' THEN 0.7
|
||||||
|
WHEN 'dart' THEN 0.5
|
||||||
|
ELSE 0.6
|
||||||
|
END
|
||||||
|
)::REAL AS weighted_score
|
||||||
|
FROM news
|
||||||
|
WHERE sentiment_score IS NOT NULL
|
||||||
|
AND code IS NOT NULL
|
||||||
|
GROUP BY code, (published_at AT TIME ZONE 'Asia/Seoul')::date;
|
||||||
|
|
||||||
|
COMMENT ON VIEW v_sentiment_daily IS
|
||||||
|
'Phase 2: KR-FinBERT 점수를 종목·일(KST) 단위로 집계. Phase 4 LGBM 피처 + UI 차트 보조 데이터로 사용.';
|
||||||
19
backend/app/db/migrations/003_ensemble_weights.sql
Normal file
19
backend/app/db/migrations/003_ensemble_weights.sql
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
-- Phase 4: 앙상블 가중치 저장.
|
||||||
|
-- (code, horizon) 별로 Chronos vs LGBM 가중치. 일요일 02:00 재학습 잡에서 갱신.
|
||||||
|
|
||||||
|
\set ON_ERROR_STOP on
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS ensemble_weights (
|
||||||
|
code TEXT NOT NULL REFERENCES symbols(code),
|
||||||
|
horizon INT NOT NULL,
|
||||||
|
w_chronos REAL NOT NULL DEFAULT 0.6,
|
||||||
|
w_lgbm REAL NOT NULL DEFAULT 0.4,
|
||||||
|
hit_rate_chronos REAL,
|
||||||
|
hit_rate_lgbm REAL,
|
||||||
|
sample_count INT,
|
||||||
|
updated_at TIMESTAMPTZ DEFAULT NOW(),
|
||||||
|
PRIMARY KEY (code, horizon)
|
||||||
|
);
|
||||||
|
|
||||||
|
COMMENT ON TABLE ensemble_weights IS
|
||||||
|
'Phase 4: (code, horizon) 별 Chronos/LGBM 가중치. 최근 30일 prediction_outcomes hit_rate 기반 매주 갱신.';
|
||||||
@@ -6,7 +6,12 @@ from contextlib import asynccontextmanager
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from app.api.chart import router as chart_router
|
||||||
|
from app.api.metrics import router as metrics_router
|
||||||
|
from app.api.news import router as news_router
|
||||||
|
from app.api.predict import router as predict_router
|
||||||
from app.api.refresh import router as refresh_router
|
from app.api.refresh import router as refresh_router
|
||||||
|
from app.api.symbols import router as symbols_router
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.db.connection import ping as db_ping
|
from app.db.connection import ping as db_ping
|
||||||
from app.fetch import dart as dart_mod
|
from app.fetch import dart as dart_mod
|
||||||
@@ -41,6 +46,11 @@ app.add_middleware(
|
|||||||
)
|
)
|
||||||
|
|
||||||
app.include_router(refresh_router)
|
app.include_router(refresh_router)
|
||||||
|
app.include_router(symbols_router)
|
||||||
|
app.include_router(chart_router)
|
||||||
|
app.include_router(predict_router)
|
||||||
|
app.include_router(metrics_router)
|
||||||
|
app.include_router(news_router)
|
||||||
|
|
||||||
|
|
||||||
def _resolved_device() -> str:
|
def _resolved_device() -> str:
|
||||||
|
|||||||
118
backend/app/models/chronos.py
Normal file
118
backend/app/models/chronos.py
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
"""Chronos zero-shot 시계열 예측 어댑터.
|
||||||
|
|
||||||
|
모델: amazon/chronos-t5-small (46M, 빠르고 RTX 3070 Ti 에 충분히 들어감).
|
||||||
|
환경변수 CHRONOS_MODEL 로 base/large 로 바꿀 수 있음.
|
||||||
|
|
||||||
|
입력: 종가 시계열 (list[float], 최소 32 step).
|
||||||
|
출력: horizon 일 quantile forecast (q10/median/q90).
|
||||||
|
|
||||||
|
lazy singleton 으로 첫 호출 시 모델 로드. 디바이스는 settings.model_device 따라.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
MODEL_NAME = os.environ.get("CHRONOS_MODEL", "amazon/chronos-t5-small")
|
||||||
|
|
||||||
|
_lock = threading.Lock()
|
||||||
|
_state: dict[str, object] = {"loaded": False, "pipe": None, "device": None}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChronosForecast:
|
||||||
|
horizon: int
|
||||||
|
median: list[float]
|
||||||
|
q10: list[float]
|
||||||
|
q90: list[float]
|
||||||
|
samples: list[list[float]] # raw samples for ensemble downstream
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_device() -> str:
|
||||||
|
import torch # lazy
|
||||||
|
|
||||||
|
pref = (settings.model_device or "auto").lower()
|
||||||
|
if pref == "cuda":
|
||||||
|
return "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
if pref == "cpu":
|
||||||
|
return "cpu"
|
||||||
|
return "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
|
||||||
|
def _load() -> None:
|
||||||
|
global _state
|
||||||
|
with _lock:
|
||||||
|
if _state["loaded"]:
|
||||||
|
return
|
||||||
|
import torch
|
||||||
|
from chronos import ChronosPipeline
|
||||||
|
|
||||||
|
token = settings.huggingface_token or None
|
||||||
|
if token:
|
||||||
|
os.environ.setdefault("HUGGINGFACE_HUB_TOKEN", token)
|
||||||
|
os.environ.setdefault("HF_TOKEN", token)
|
||||||
|
|
||||||
|
device = _resolve_device()
|
||||||
|
# bf16 은 RTX 30xx 이상에서 지원. cpu 에선 fp32.
|
||||||
|
dtype = torch.bfloat16 if device == "cuda" else torch.float32
|
||||||
|
logger.info("loading Chronos %s on %s (dtype=%s)", MODEL_NAME, device, dtype)
|
||||||
|
pipe = ChronosPipeline.from_pretrained(
|
||||||
|
MODEL_NAME,
|
||||||
|
device_map=device,
|
||||||
|
torch_dtype=dtype,
|
||||||
|
)
|
||||||
|
_state.update({"loaded": True, "pipe": pipe, "device": device})
|
||||||
|
|
||||||
|
|
||||||
|
def forecast(
|
||||||
|
series: list[float],
|
||||||
|
*,
|
||||||
|
horizon: int = 5,
|
||||||
|
num_samples: int = 30,
|
||||||
|
) -> ChronosForecast:
|
||||||
|
"""series 의 마지막 시점 이후 horizon 일 예측.
|
||||||
|
|
||||||
|
series 는 일봉 종가. 최소 32개 권장 (그보다 짧으면 Chronos 분위 안정성 떨어짐).
|
||||||
|
"""
|
||||||
|
if len(series) < 32:
|
||||||
|
raise ValueError(
|
||||||
|
f"series too short ({len(series)}) for Chronos forecast (need >=32)"
|
||||||
|
)
|
||||||
|
_load()
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
pipe = _state["pipe"]
|
||||||
|
context = torch.tensor([float(x) for x in series], dtype=torch.float32)
|
||||||
|
with torch.no_grad():
|
||||||
|
samples = pipe.predict(
|
||||||
|
context=context,
|
||||||
|
prediction_length=horizon,
|
||||||
|
num_samples=num_samples,
|
||||||
|
)
|
||||||
|
# samples: (1, num_samples, prediction_length)
|
||||||
|
arr = samples[0].cpu().float().numpy()
|
||||||
|
q10 = np.quantile(arr, 0.10, axis=0).tolist()
|
||||||
|
median = np.quantile(arr, 0.50, axis=0).tolist()
|
||||||
|
q90 = np.quantile(arr, 0.90, axis=0).tolist()
|
||||||
|
return ChronosForecast(
|
||||||
|
horizon=horizon,
|
||||||
|
median=[float(x) for x in median],
|
||||||
|
q10=[float(x) for x in q10],
|
||||||
|
q90=[float(x) for x in q90],
|
||||||
|
samples=[[float(x) for x in row] for row in arr.tolist()],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def ping() -> dict[str, object]:
|
||||||
|
try:
|
||||||
|
_load()
|
||||||
|
return {"status": "ok", "model": MODEL_NAME, "device": _state["device"]}
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
return {"status": "failed", "model": MODEL_NAME, "error": str(exc)}
|
||||||
174
backend/app/models/ensemble.py
Normal file
174
backend/app/models/ensemble.py
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
"""Chronos + LGBM 앙상블 추론.
|
||||||
|
|
||||||
|
final_price[h] = w_c * chronos.median[h-1] + w_l * lgbm.predicted_close
|
||||||
|
final_q10[h] = w_c * chronos.q10[h-1] + w_l * lgbm.predicted_close * 0.97
|
||||||
|
final_q90[h] = w_c * chronos.q90[h-1] + w_l * lgbm.predicted_close * 1.03
|
||||||
|
|
||||||
|
LGBM 은 단일 horizon 의 다음 종가(point) 만 주므로, 그 자체로는 신뢰구간이 없음.
|
||||||
|
근사로 ±3% band 를 LGBM 의 q10/q90 자리에 사용. Chronos 의 sample 분포가
|
||||||
|
주된 신뢰구간 정보 (Chronos 우세하면 ci 가 좁아짐).
|
||||||
|
|
||||||
|
direction 확률:
|
||||||
|
- LGBM 분류기에서 prob_up/flat/down (3-class) 그대로
|
||||||
|
- Chronos 는 next-day return 부호 비율: samples.shift1 / base_close - 1 의 부호
|
||||||
|
- 둘을 같은 가중치로 평균
|
||||||
|
|
||||||
|
LGBM 모델이 없으면 Chronos 단독으로 진행 (cold start).
|
||||||
|
Chronos 도 실패하면 LGBM 단독으로 진행.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from app.models.chronos import ChronosForecast
|
||||||
|
from app.models.chronos import forecast as chronos_forecast
|
||||||
|
from app.models.lgbm import LgbmForecast
|
||||||
|
from app.models.lgbm import predict_one as lgbm_predict
|
||||||
|
from app.models.weights import load_weights
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EnsembleStep:
|
||||||
|
horizon: int # 1..H 거래일 후
|
||||||
|
target_idx: int # chronos median 의 0-based 인덱스 (horizon-1)
|
||||||
|
point_close: float
|
||||||
|
ci_low: float
|
||||||
|
ci_high: float
|
||||||
|
prob_up: float
|
||||||
|
prob_flat: float
|
||||||
|
prob_down: float
|
||||||
|
direction: str # 'up' / 'flat' / 'down'
|
||||||
|
expected_return: float
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EnsemblePrediction:
|
||||||
|
code: str
|
||||||
|
base_close: float
|
||||||
|
horizons: list[int]
|
||||||
|
steps: list[EnsembleStep]
|
||||||
|
sources_used: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
def _chronos_direction(samples: list[list[float]], base_close: float, horizon: int) -> tuple[float, float, float]:
|
||||||
|
"""Chronos sample 분포에서 (prob_up, prob_flat, prob_down). ±0.3% flat band."""
|
||||||
|
if not samples:
|
||||||
|
return 0.33, 0.34, 0.33
|
||||||
|
arr = np.array(samples)[:, horizon - 1] # 해당 step 의 sample 값
|
||||||
|
ret = arr / base_close - 1.0
|
||||||
|
p_up = float((ret > 0.003).mean())
|
||||||
|
p_dn = float((ret < -0.003).mean())
|
||||||
|
p_fl = 1.0 - p_up - p_dn
|
||||||
|
return p_up, p_fl, p_dn
|
||||||
|
|
||||||
|
|
||||||
|
def predict(code: str, *, horizons: tuple[int, ...] = (1, 3, 5)) -> EnsemblePrediction:
|
||||||
|
"""한 종목에 대해 horizons 별 앙상블 예측. on-demand 추론용."""
|
||||||
|
max_h = max(horizons)
|
||||||
|
|
||||||
|
# Chronos: 종가 시계열 가져와서 max_h 까지 예측.
|
||||||
|
from app.models.features import build_features # local import
|
||||||
|
|
||||||
|
ff = build_features(code, lookback_days=400, horizons=horizons, with_targets=False)
|
||||||
|
df = ff.df
|
||||||
|
if df.empty:
|
||||||
|
raise RuntimeError(f"no OHLCV data for {code}")
|
||||||
|
closes = df["close"].astype(float).tolist()
|
||||||
|
base_close = float(closes[-1])
|
||||||
|
|
||||||
|
sources_used: list[str] = []
|
||||||
|
cf: ChronosForecast | None = None
|
||||||
|
try:
|
||||||
|
cf = chronos_forecast(closes, horizon=max_h, num_samples=30)
|
||||||
|
sources_used.append("chronos")
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
logger.warning("chronos forecast failed for %s: %s", code, exc)
|
||||||
|
|
||||||
|
steps: list[EnsembleStep] = []
|
||||||
|
for h in horizons:
|
||||||
|
lf: LgbmForecast | None = None
|
||||||
|
try:
|
||||||
|
lf = lgbm_predict(code, h)
|
||||||
|
if lf is not None:
|
||||||
|
sources_used.append(f"lgbm_h{h}")
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
logger.warning("lgbm predict failed for %s h=%d: %s", code, h, exc)
|
||||||
|
|
||||||
|
# 가중치 (DB 없으면 default 0.6/0.4).
|
||||||
|
w = load_weights(code, h)
|
||||||
|
wc, wl = w.w_chronos, w.w_lgbm
|
||||||
|
# 한쪽이 없으면 다른 쪽 전부.
|
||||||
|
if cf is None and lf is None:
|
||||||
|
raise RuntimeError(f"both chronos & lgbm failed for {code} h={h}")
|
||||||
|
if cf is None:
|
||||||
|
wc, wl = 0.0, 1.0
|
||||||
|
if lf is None:
|
||||||
|
wc, wl = 1.0, 0.0
|
||||||
|
|
||||||
|
if cf is not None:
|
||||||
|
c_med = cf.median[h - 1]
|
||||||
|
c_q10 = cf.q10[h - 1]
|
||||||
|
c_q90 = cf.q90[h - 1]
|
||||||
|
else:
|
||||||
|
c_med = c_q10 = c_q90 = base_close # not used (wc=0)
|
||||||
|
|
||||||
|
if lf is not None:
|
||||||
|
l_close = lf.predicted_close
|
||||||
|
l_lo = l_close * 0.97
|
||||||
|
l_hi = l_close * 1.03
|
||||||
|
l_pu, l_pf, l_pd = lf.prob_up, lf.prob_flat, lf.prob_down
|
||||||
|
else:
|
||||||
|
l_close = l_lo = l_hi = base_close
|
||||||
|
l_pu = l_pf = l_pd = 0.0
|
||||||
|
|
||||||
|
point = wc * c_med + wl * l_close
|
||||||
|
lo = wc * c_q10 + wl * l_lo
|
||||||
|
hi = wc * c_q90 + wl * l_hi
|
||||||
|
|
||||||
|
if cf is not None:
|
||||||
|
cp_up, cp_fl, cp_dn = _chronos_direction(cf.samples, base_close, h)
|
||||||
|
else:
|
||||||
|
cp_up = cp_fl = cp_dn = 0.0
|
||||||
|
|
||||||
|
# direction prob: source 마다 weights 동일하게 가중평균
|
||||||
|
if lf is not None and cf is not None:
|
||||||
|
p_up = 0.5 * cp_up + 0.5 * l_pu
|
||||||
|
p_fl = 0.5 * cp_fl + 0.5 * l_pf
|
||||||
|
p_dn = 0.5 * cp_dn + 0.5 * l_pd
|
||||||
|
elif cf is not None:
|
||||||
|
p_up, p_fl, p_dn = cp_up, cp_fl, cp_dn
|
||||||
|
else:
|
||||||
|
p_up, p_fl, p_dn = l_pu, l_pf, l_pd
|
||||||
|
|
||||||
|
# 정규화 (혹시 합이 0 가 아닐 때)
|
||||||
|
s = max(p_up + p_fl + p_dn, 1e-9)
|
||||||
|
p_up, p_fl, p_dn = p_up / s, p_fl / s, p_dn / s
|
||||||
|
dir_lbl = "up" if p_up >= max(p_fl, p_dn) else ("down" if p_dn >= p_fl else "flat")
|
||||||
|
|
||||||
|
steps.append(
|
||||||
|
EnsembleStep(
|
||||||
|
horizon=h,
|
||||||
|
target_idx=h - 1,
|
||||||
|
point_close=float(point),
|
||||||
|
ci_low=float(lo),
|
||||||
|
ci_high=float(hi),
|
||||||
|
prob_up=float(p_up),
|
||||||
|
prob_flat=float(p_fl),
|
||||||
|
prob_down=float(p_dn),
|
||||||
|
direction=dir_lbl,
|
||||||
|
expected_return=float(point / base_close - 1.0),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return EnsemblePrediction(
|
||||||
|
code=code,
|
||||||
|
base_close=base_close,
|
||||||
|
horizons=list(horizons),
|
||||||
|
steps=steps,
|
||||||
|
sources_used=sources_used,
|
||||||
|
)
|
||||||
223
backend/app/models/features.py
Normal file
223
backend/app/models/features.py
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
"""모델 학습/추론용 피처 빌더.
|
||||||
|
|
||||||
|
종목 1개 + 룩백 기간을 받아 (date 단위) DataFrame 반환:
|
||||||
|
- OHLCV
|
||||||
|
- returns r1
|
||||||
|
- TA: rsi14, macd, macd_signal, atr14, bb_pct, sma20, ema12, vol_z20
|
||||||
|
- trading_value: foreign_net, institution_net, individual_net (정규화 X, scale 그대로)
|
||||||
|
- macro 정렬: kospi, kosdaq, usdkrw, us10y, kospi_r1, usdkrw_r1
|
||||||
|
- sentiment (v_sentiment_daily): mean_score, weighted_score, n_articles,
|
||||||
|
pos_minus_neg = pos_ratio - neg_ratio. 3일 롤링 mean 도 추가.
|
||||||
|
|
||||||
|
학습 타깃 (build_features 에서만 생성):
|
||||||
|
- y_close_h{1,3,5}: close.shift(-H)
|
||||||
|
- y_ret_h{1,3,5}: y_close_h / close - 1
|
||||||
|
- y_dir_h{1,3,5}: sign(y_ret_h) (1=up, -1=down, 0=flat ±0.3% 이내)
|
||||||
|
|
||||||
|
inference 용 build_features 는 dropna 안 함. 학습용 build_training_frame 은 dropna.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import date, timedelta
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from app.db.connection import get_engine
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
FLAT_BAND = 0.003 # ±0.3% 이내는 flat
|
||||||
|
HORIZONS_DEFAULT = (1, 3, 5)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FeatureFrame:
|
||||||
|
code: str
|
||||||
|
df: pd.DataFrame
|
||||||
|
target_horizons: tuple[int, ...]
|
||||||
|
|
||||||
|
|
||||||
|
def _load_ohlcv(code: str, start: date, end: date) -> pd.DataFrame:
|
||||||
|
eng = get_engine()
|
||||||
|
sql = text(
|
||||||
|
"""
|
||||||
|
SELECT date, open, high, low, close, volume
|
||||||
|
FROM ohlcv_daily
|
||||||
|
WHERE code = :code AND date BETWEEN :s AND :e
|
||||||
|
ORDER BY date
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
with eng.connect() as conn:
|
||||||
|
rows = conn.execute(sql, {"code": code, "s": start, "e": end}).all()
|
||||||
|
if not rows:
|
||||||
|
return pd.DataFrame(columns=["date", "open", "high", "low", "close", "volume"])
|
||||||
|
df = pd.DataFrame(rows, columns=["date", "open", "high", "low", "close", "volume"])
|
||||||
|
df["date"] = pd.to_datetime(df["date"]).dt.date
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def _load_trading(code: str, start: date, end: date) -> pd.DataFrame:
|
||||||
|
eng = get_engine()
|
||||||
|
sql = text(
|
||||||
|
"""
|
||||||
|
SELECT date, foreign_net, institution_net, individual_net
|
||||||
|
FROM trading_value_daily
|
||||||
|
WHERE code = :code AND date BETWEEN :s AND :e
|
||||||
|
ORDER BY date
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
with eng.connect() as conn:
|
||||||
|
rows = conn.execute(sql, {"code": code, "s": start, "e": end}).all()
|
||||||
|
if not rows:
|
||||||
|
return pd.DataFrame(columns=["date", "foreign_net", "institution_net", "individual_net"])
|
||||||
|
df = pd.DataFrame(rows, columns=["date", "foreign_net", "institution_net", "individual_net"])
|
||||||
|
df["date"] = pd.to_datetime(df["date"]).dt.date
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def _load_macro(start: date, end: date) -> pd.DataFrame:
|
||||||
|
eng = get_engine()
|
||||||
|
sql = text(
|
||||||
|
"SELECT date, key, value FROM macro_daily "
|
||||||
|
"WHERE date BETWEEN :s AND :e ORDER BY date"
|
||||||
|
)
|
||||||
|
with eng.connect() as conn:
|
||||||
|
rows = conn.execute(sql, {"s": start, "e": end}).all()
|
||||||
|
if not rows:
|
||||||
|
return pd.DataFrame(columns=["date"])
|
||||||
|
df = pd.DataFrame(rows, columns=["date", "key", "value"])
|
||||||
|
pivot = df.pivot_table(index="date", columns="key", values="value", aggfunc="last").reset_index()
|
||||||
|
pivot["date"] = pd.to_datetime(pivot["date"]).dt.date
|
||||||
|
pivot.columns.name = None
|
||||||
|
return pivot
|
||||||
|
|
||||||
|
|
||||||
|
def _load_sentiment(code: str, start: date, end: date) -> pd.DataFrame:
|
||||||
|
eng = get_engine()
|
||||||
|
sql = text(
|
||||||
|
"""
|
||||||
|
SELECT date, n_articles, mean_score, pos_ratio, neg_ratio,
|
||||||
|
weighted_score
|
||||||
|
FROM v_sentiment_daily
|
||||||
|
WHERE code = :code AND date BETWEEN :s AND :e
|
||||||
|
ORDER BY date
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
with eng.connect() as conn:
|
||||||
|
rows = conn.execute(sql, {"code": code, "s": start, "e": end}).all()
|
||||||
|
cols = ["date", "n_articles", "mean_score", "pos_ratio", "neg_ratio", "weighted_score"]
|
||||||
|
if not rows:
|
||||||
|
return pd.DataFrame(columns=cols)
|
||||||
|
df = pd.DataFrame(rows, columns=cols)
|
||||||
|
df["date"] = pd.to_datetime(df["date"]).dt.date
|
||||||
|
df["pos_minus_neg"] = df["pos_ratio"].fillna(0) - df["neg_ratio"].fillna(0)
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def _add_ta(df: pd.DataFrame) -> pd.DataFrame:
|
||||||
|
"""ta 패키지로 기술 지표 추가."""
|
||||||
|
from ta.momentum import RSIIndicator
|
||||||
|
from ta.trend import EMAIndicator, MACD, SMAIndicator
|
||||||
|
from ta.volatility import AverageTrueRange, BollingerBands
|
||||||
|
|
||||||
|
close = df["close"].astype(float)
|
||||||
|
high = df["high"].astype(float)
|
||||||
|
low = df["low"].astype(float)
|
||||||
|
vol = df["volume"].astype(float)
|
||||||
|
|
||||||
|
df["r1"] = close.pct_change()
|
||||||
|
df["rsi14"] = RSIIndicator(close=close, window=14, fillna=False).rsi()
|
||||||
|
macd = MACD(close=close, window_slow=26, window_fast=12, window_sign=9, fillna=False)
|
||||||
|
df["macd"] = macd.macd()
|
||||||
|
df["macd_signal"] = macd.macd_signal()
|
||||||
|
df["atr14"] = AverageTrueRange(high=high, low=low, close=close, window=14, fillna=False).average_true_range()
|
||||||
|
bb = BollingerBands(close=close, window=20, window_dev=2, fillna=False)
|
||||||
|
df["bb_pct"] = bb.bollinger_pband()
|
||||||
|
df["sma20"] = SMAIndicator(close=close, window=20, fillna=False).sma_indicator()
|
||||||
|
df["ema12"] = EMAIndicator(close=close, window=12, fillna=False).ema_indicator()
|
||||||
|
vol_mean = vol.rolling(20).mean()
|
||||||
|
vol_std = vol.rolling(20).std().replace(0, np.nan)
|
||||||
|
df["vol_z20"] = (vol - vol_mean) / vol_std
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def _add_targets(df: pd.DataFrame, horizons: tuple[int, ...]) -> pd.DataFrame:
|
||||||
|
close = df["close"].astype(float)
|
||||||
|
for h in horizons:
|
||||||
|
df[f"y_close_h{h}"] = close.shift(-h)
|
||||||
|
df[f"y_ret_h{h}"] = df[f"y_close_h{h}"] / close - 1.0
|
||||||
|
df[f"y_dir_h{h}"] = np.where(
|
||||||
|
df[f"y_ret_h{h}"] > FLAT_BAND, 1,
|
||||||
|
np.where(df[f"y_ret_h{h}"] < -FLAT_BAND, -1, 0),
|
||||||
|
)
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def build_features(
|
||||||
|
code: str,
|
||||||
|
*,
|
||||||
|
lookback_days: int = 365 * 2,
|
||||||
|
end_date: date | None = None,
|
||||||
|
horizons: tuple[int, ...] = HORIZONS_DEFAULT,
|
||||||
|
with_targets: bool = False,
|
||||||
|
) -> FeatureFrame:
|
||||||
|
"""code 1개 종목의 피처 DataFrame 생성.
|
||||||
|
|
||||||
|
inference: with_targets=False 로 호출 → 최신 row 의 피처만 LGBM/Chronos 에 투입.
|
||||||
|
training : with_targets=True 로 호출 → tail H 행은 타깃 NaN → dropna 로 제거.
|
||||||
|
"""
|
||||||
|
end = end_date or date.today()
|
||||||
|
start = end - timedelta(days=lookback_days)
|
||||||
|
|
||||||
|
ohlcv = _load_ohlcv(code, start, end)
|
||||||
|
if ohlcv.empty:
|
||||||
|
return FeatureFrame(code=code, df=ohlcv, target_horizons=horizons)
|
||||||
|
|
||||||
|
df = ohlcv.copy().sort_values("date").reset_index(drop=True)
|
||||||
|
|
||||||
|
df = _add_ta(df)
|
||||||
|
|
||||||
|
trading = _load_trading(code, start, end)
|
||||||
|
if not trading.empty:
|
||||||
|
df = df.merge(trading, on="date", how="left")
|
||||||
|
else:
|
||||||
|
for col in ("foreign_net", "institution_net", "individual_net"):
|
||||||
|
df[col] = np.nan
|
||||||
|
|
||||||
|
macro = _load_macro(start, end)
|
||||||
|
if not macro.empty:
|
||||||
|
df = df.merge(macro, on="date", how="left")
|
||||||
|
for k in ("kospi", "kosdaq", "usdkrw", "us10y"):
|
||||||
|
if k in df.columns:
|
||||||
|
df[f"{k}_r1"] = df[k].pct_change()
|
||||||
|
|
||||||
|
sentiment = _load_sentiment(code, start, end)
|
||||||
|
if not sentiment.empty:
|
||||||
|
df = df.merge(sentiment, on="date", how="left")
|
||||||
|
# 3일 롤링 평균
|
||||||
|
for col in ("mean_score", "weighted_score", "pos_minus_neg", "n_articles"):
|
||||||
|
if col in df.columns:
|
||||||
|
df[f"{col}_3d"] = df[col].rolling(3, min_periods=1).mean()
|
||||||
|
else:
|
||||||
|
for col in ("n_articles", "mean_score", "pos_ratio", "neg_ratio",
|
||||||
|
"weighted_score", "pos_minus_neg"):
|
||||||
|
df[col] = np.nan
|
||||||
|
|
||||||
|
if with_targets:
|
||||||
|
df = _add_targets(df, horizons)
|
||||||
|
|
||||||
|
return FeatureFrame(code=code, df=df, target_horizons=horizons)
|
||||||
|
|
||||||
|
|
||||||
|
def feature_columns(df: pd.DataFrame) -> list[str]:
|
||||||
|
"""LGBM 학습/추론용 피처 컬럼 목록. date / OHLCV / y_* 제외."""
|
||||||
|
drop = {"date", "open", "high", "low", "close", "volume"}
|
||||||
|
cols = [
|
||||||
|
c for c in df.columns
|
||||||
|
if c not in drop and not c.startswith("y_")
|
||||||
|
]
|
||||||
|
return cols
|
||||||
180
backend/app/models/lgbm.py
Normal file
180
backend/app/models/lgbm.py
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
"""LightGBM 회귀 + 분류 모델. 종목 × horizon 별 별도 저장.
|
||||||
|
|
||||||
|
- 회귀: target = y_ret_h{H}. 예측 후 base_close*(1+pred) 로 가격 환산.
|
||||||
|
- 분류: target = y_dir_h{H} ∈ {-1, 0, +1}. 3-class softmax 로 prob_up/flat/down.
|
||||||
|
|
||||||
|
저장 경로: backend/data/models/{code}_h{H}_reg.pkl, _cls.pkl (joblib).
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import joblib
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
from app.models.features import build_features, feature_columns
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
MODEL_DIR = Path(os.environ.get("LGBM_MODEL_DIR", "/app/data/models"))
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LgbmForecast:
|
||||||
|
horizon: int
|
||||||
|
base_close: float
|
||||||
|
predicted_close: float
|
||||||
|
predicted_return: float
|
||||||
|
prob_up: float
|
||||||
|
prob_flat: float
|
||||||
|
prob_down: float
|
||||||
|
|
||||||
|
|
||||||
|
def _model_paths(code: str, horizon: int) -> tuple[Path, Path]:
|
||||||
|
MODEL_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
return (
|
||||||
|
MODEL_DIR / f"{code}_h{horizon}_reg.pkl",
|
||||||
|
MODEL_DIR / f"{code}_h{horizon}_cls.pkl",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _prepare_xy(code: str, horizon: int, lookback_days: int) -> tuple[pd.DataFrame, pd.Series, pd.Series, list[str]]:
|
||||||
|
ff = build_features(
|
||||||
|
code,
|
||||||
|
lookback_days=lookback_days,
|
||||||
|
horizons=(horizon,),
|
||||||
|
with_targets=True,
|
||||||
|
)
|
||||||
|
df = ff.df
|
||||||
|
if df.empty:
|
||||||
|
return df, pd.Series(dtype=float), pd.Series(dtype=int), []
|
||||||
|
y_ret_col = f"y_ret_h{horizon}"
|
||||||
|
y_dir_col = f"y_dir_h{horizon}"
|
||||||
|
# 타깃 NaN (마지막 H 행) 제거.
|
||||||
|
df = df.dropna(subset=[y_ret_col, y_dir_col])
|
||||||
|
feats = feature_columns(df)
|
||||||
|
if not feats:
|
||||||
|
return df, pd.Series(dtype=float), pd.Series(dtype=int), []
|
||||||
|
X = df[feats]
|
||||||
|
# LightGBM 은 NaN 자체 처리 가능.
|
||||||
|
y_ret = df[y_ret_col].astype(float)
|
||||||
|
y_dir = df[y_dir_col].astype(int)
|
||||||
|
return X, y_ret, y_dir, feats
|
||||||
|
|
||||||
|
|
||||||
|
def train_one(code: str, horizon: int, *, lookback_days: int = 365 * 3) -> dict:
|
||||||
|
"""1종목 × 1 horizon 학습. 저장된 모델 파일 경로 + 샘플 수 반환."""
|
||||||
|
import lightgbm as lgb
|
||||||
|
|
||||||
|
X, y_ret, y_dir, feats = _prepare_xy(code, horizon, lookback_days)
|
||||||
|
if X.empty or len(X) < 100:
|
||||||
|
return {"code": code, "horizon": horizon, "status": "skipped_too_few_rows", "n_rows": int(len(X))}
|
||||||
|
|
||||||
|
reg_params = dict(
|
||||||
|
objective="regression",
|
||||||
|
learning_rate=0.05,
|
||||||
|
num_leaves=31,
|
||||||
|
min_data_in_leaf=20,
|
||||||
|
feature_fraction=0.85,
|
||||||
|
bagging_fraction=0.8,
|
||||||
|
bagging_freq=5,
|
||||||
|
verbose=-1,
|
||||||
|
)
|
||||||
|
cls_params = dict(
|
||||||
|
objective="multiclass",
|
||||||
|
num_class=3,
|
||||||
|
learning_rate=0.05,
|
||||||
|
num_leaves=31,
|
||||||
|
min_data_in_leaf=20,
|
||||||
|
feature_fraction=0.85,
|
||||||
|
bagging_fraction=0.8,
|
||||||
|
bagging_freq=5,
|
||||||
|
verbose=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 분류는 -1/0/1 → 0/1/2 인덱스로 매핑.
|
||||||
|
y_dir_idx = (y_dir + 1).astype(int)
|
||||||
|
|
||||||
|
n = len(X)
|
||||||
|
split = int(n * 0.85)
|
||||||
|
X_tr, X_val = X.iloc[:split], X.iloc[split:]
|
||||||
|
yr_tr, yr_val = y_ret.iloc[:split], y_ret.iloc[split:]
|
||||||
|
yc_tr, yc_val = y_dir_idx.iloc[:split], y_dir_idx.iloc[split:]
|
||||||
|
|
||||||
|
reg_train = lgb.Dataset(X_tr, label=yr_tr)
|
||||||
|
reg_valid = lgb.Dataset(X_val, label=yr_val, reference=reg_train)
|
||||||
|
reg_model = lgb.train(
|
||||||
|
reg_params,
|
||||||
|
reg_train,
|
||||||
|
num_boost_round=400,
|
||||||
|
valid_sets=[reg_valid],
|
||||||
|
callbacks=[lgb.early_stopping(stopping_rounds=30, verbose=False)],
|
||||||
|
)
|
||||||
|
|
||||||
|
cls_train = lgb.Dataset(X_tr, label=yc_tr)
|
||||||
|
cls_valid = lgb.Dataset(X_val, label=yc_val, reference=cls_train)
|
||||||
|
cls_model = lgb.train(
|
||||||
|
cls_params,
|
||||||
|
cls_train,
|
||||||
|
num_boost_round=400,
|
||||||
|
valid_sets=[cls_valid],
|
||||||
|
callbacks=[lgb.early_stopping(stopping_rounds=30, verbose=False)],
|
||||||
|
)
|
||||||
|
|
||||||
|
reg_path, cls_path = _model_paths(code, horizon)
|
||||||
|
joblib.dump({"model": reg_model, "features": feats}, reg_path)
|
||||||
|
joblib.dump({"model": cls_model, "features": feats}, cls_path)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"code": code,
|
||||||
|
"horizon": horizon,
|
||||||
|
"status": "ok",
|
||||||
|
"n_rows": int(len(X)),
|
||||||
|
"reg_best_iter": int(reg_model.best_iteration or 0),
|
||||||
|
"cls_best_iter": int(cls_model.best_iteration or 0),
|
||||||
|
"reg_path": str(reg_path),
|
||||||
|
"cls_path": str(cls_path),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def predict_one(code: str, horizon: int, *, lookback_days: int = 400) -> LgbmForecast | None:
|
||||||
|
"""1종목 × 1 horizon 추론. 모델 없으면 None.
|
||||||
|
|
||||||
|
가장 최신 영업일 피처를 사용. base_close 는 그 행의 close.
|
||||||
|
"""
|
||||||
|
reg_path, cls_path = _model_paths(code, horizon)
|
||||||
|
if not reg_path.exists() or not cls_path.exists():
|
||||||
|
return None
|
||||||
|
reg_blob = joblib.load(reg_path)
|
||||||
|
cls_blob = joblib.load(cls_path)
|
||||||
|
feats_reg = reg_blob["features"]
|
||||||
|
feats_cls = cls_blob["features"]
|
||||||
|
reg_model = reg_blob["model"]
|
||||||
|
cls_model = cls_blob["model"]
|
||||||
|
|
||||||
|
ff = build_features(code, lookback_days=lookback_days, horizons=(horizon,), with_targets=False)
|
||||||
|
df = ff.df
|
||||||
|
if df.empty:
|
||||||
|
return None
|
||||||
|
last = df.iloc[[-1]]
|
||||||
|
base_close = float(last["close"].iloc[0])
|
||||||
|
# 피처 정렬 (모델이 학습 당시 본 컬럼 순서대로).
|
||||||
|
X_reg = last.reindex(columns=feats_reg).fillna(value=np.nan)
|
||||||
|
X_cls = last.reindex(columns=feats_cls).fillna(value=np.nan)
|
||||||
|
pred_ret = float(reg_model.predict(X_reg)[0])
|
||||||
|
probs = cls_model.predict(X_cls)[0]
|
||||||
|
# 인덱스 0=-1(down), 1=0(flat), 2=+1(up)
|
||||||
|
prob_down, prob_flat, prob_up = float(probs[0]), float(probs[1]), float(probs[2])
|
||||||
|
return LgbmForecast(
|
||||||
|
horizon=horizon,
|
||||||
|
base_close=base_close,
|
||||||
|
predicted_close=base_close * (1.0 + pred_ret),
|
||||||
|
predicted_return=pred_ret,
|
||||||
|
prob_up=prob_up,
|
||||||
|
prob_flat=prob_flat,
|
||||||
|
prob_down=prob_down,
|
||||||
|
)
|
||||||
75
backend/app/models/weights.py
Normal file
75
backend/app/models/weights.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
"""ensemble_weights 테이블 IO. 기본 가중치 (chronos 0.6, lgbm 0.4)."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from app.db.connection import get_engine
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EnsembleWeights:
|
||||||
|
code: str
|
||||||
|
horizon: int
|
||||||
|
w_chronos: float
|
||||||
|
w_lgbm: float
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_W_CHRONOS = 0.6
|
||||||
|
DEFAULT_W_LGBM = 0.4
|
||||||
|
|
||||||
|
|
||||||
|
def load_weights(code: str, horizon: int) -> EnsembleWeights:
|
||||||
|
eng = get_engine()
|
||||||
|
with eng.connect() as conn:
|
||||||
|
row = conn.execute(
|
||||||
|
text(
|
||||||
|
"SELECT w_chronos, w_lgbm FROM ensemble_weights "
|
||||||
|
"WHERE code = :code AND horizon = :h"
|
||||||
|
),
|
||||||
|
{"code": code, "h": horizon},
|
||||||
|
).first()
|
||||||
|
if not row:
|
||||||
|
return EnsembleWeights(code, horizon, DEFAULT_W_CHRONOS, DEFAULT_W_LGBM)
|
||||||
|
return EnsembleWeights(code, horizon, float(row[0]), float(row[1]))
|
||||||
|
|
||||||
|
|
||||||
|
def upsert_weights(
|
||||||
|
code: str,
|
||||||
|
horizon: int,
|
||||||
|
w_chronos: float,
|
||||||
|
w_lgbm: float,
|
||||||
|
*,
|
||||||
|
hit_rate_chronos: float | None = None,
|
||||||
|
hit_rate_lgbm: float | None = None,
|
||||||
|
sample_count: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
eng = get_engine()
|
||||||
|
with eng.begin() as conn:
|
||||||
|
conn.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
INSERT INTO ensemble_weights
|
||||||
|
(code, horizon, w_chronos, w_lgbm, hit_rate_chronos, hit_rate_lgbm, sample_count, updated_at)
|
||||||
|
VALUES
|
||||||
|
(:code, :h, :wc, :wl, :hc, :hl, :n, NOW())
|
||||||
|
ON CONFLICT (code, horizon) DO UPDATE SET
|
||||||
|
w_chronos = EXCLUDED.w_chronos,
|
||||||
|
w_lgbm = EXCLUDED.w_lgbm,
|
||||||
|
hit_rate_chronos = EXCLUDED.hit_rate_chronos,
|
||||||
|
hit_rate_lgbm = EXCLUDED.hit_rate_lgbm,
|
||||||
|
sample_count = EXCLUDED.sample_count,
|
||||||
|
updated_at = NOW()
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{
|
||||||
|
"code": code,
|
||||||
|
"h": horizon,
|
||||||
|
"wc": float(w_chronos),
|
||||||
|
"wl": float(w_lgbm),
|
||||||
|
"hc": hit_rate_chronos,
|
||||||
|
"hl": hit_rate_lgbm,
|
||||||
|
"n": sample_count,
|
||||||
|
},
|
||||||
|
)
|
||||||
0
backend/app/nlp/__init__.py
Normal file
0
backend/app/nlp/__init__.py
Normal file
150
backend/app/nlp/finbert.py
Normal file
150
backend/app/nlp/finbert.py
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
"""KR-FinBERT 감성 분석 어댑터.
|
||||||
|
|
||||||
|
모델: snunlp/KR-FinBert-SC (3-class: negative / neutral / positive)
|
||||||
|
|
||||||
|
- score : prob(positive) - prob(negative) ∈ [-1, +1]
|
||||||
|
- label : argmax 결과 ('positive' / 'neutral' / 'negative')
|
||||||
|
- embedding : 마지막 hidden state mean pool (768d) — `news.embedding` (VECTOR) 저장용
|
||||||
|
|
||||||
|
디바이스: settings.model_device ('auto' → cuda 가용 시 cuda, 아니면 cpu).
|
||||||
|
인증: settings.huggingface_token (gated 모델은 아니지만 HF rate limit 우회 + 일관성).
|
||||||
|
캐시: HF_HOME=/root/.cache/huggingface (docker-compose 의 `hf_cache` 볼륨).
|
||||||
|
|
||||||
|
lazy singleton — FastAPI 기동 시점에 모델을 로드하지 않고, 첫 score_texts() 호출
|
||||||
|
또는 ping() 호출 시점에 로드.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
MODEL_NAME = "snunlp/KR-FinBert-SC"
|
||||||
|
# KR-FinBert-SC 의 id2label : {0: 'negative', 1: 'neutral', 2: 'positive'}
|
||||||
|
_LABELS = ("negative", "neutral", "positive")
|
||||||
|
|
||||||
|
_lock = threading.Lock()
|
||||||
|
_state: dict[str, object] = {
|
||||||
|
"loaded": False,
|
||||||
|
"tokenizer": None,
|
||||||
|
"model": None,
|
||||||
|
"device": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FinbertOutput:
|
||||||
|
label: str
|
||||||
|
score: float # prob_positive - prob_negative ∈ [-1, +1]
|
||||||
|
prob_negative: float
|
||||||
|
prob_neutral: float
|
||||||
|
prob_positive: float
|
||||||
|
embedding: list[float] # 768d mean-pooled last hidden state
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_device() -> str:
|
||||||
|
"""settings.model_device 값에 따라 'cuda'/'cpu' 결정."""
|
||||||
|
import torch # lazy
|
||||||
|
|
||||||
|
pref = (settings.model_device or "auto").lower()
|
||||||
|
if pref == "cuda":
|
||||||
|
return "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
if pref == "cpu":
|
||||||
|
return "cpu"
|
||||||
|
# 'auto'
|
||||||
|
return "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
|
||||||
|
def _load() -> None:
|
||||||
|
global _state
|
||||||
|
with _lock:
|
||||||
|
if _state["loaded"]:
|
||||||
|
return
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||||
|
|
||||||
|
token = settings.huggingface_token or None
|
||||||
|
if token:
|
||||||
|
# transformers/datasets 모두 이 env 를 인식.
|
||||||
|
os.environ.setdefault("HUGGINGFACE_HUB_TOKEN", token)
|
||||||
|
os.environ.setdefault("HF_TOKEN", token)
|
||||||
|
|
||||||
|
device = _resolve_device()
|
||||||
|
logger.info("loading %s on %s", MODEL_NAME, device)
|
||||||
|
tok = AutoTokenizer.from_pretrained(MODEL_NAME, token=token)
|
||||||
|
mdl = AutoModelForSequenceClassification.from_pretrained(
|
||||||
|
MODEL_NAME,
|
||||||
|
token=token,
|
||||||
|
output_hidden_states=True,
|
||||||
|
)
|
||||||
|
mdl.eval()
|
||||||
|
mdl.to(device)
|
||||||
|
_state.update({"loaded": True, "tokenizer": tok, "model": mdl, "device": device})
|
||||||
|
logger.info("KR-FinBERT loaded (device=%s)", device)
|
||||||
|
|
||||||
|
|
||||||
|
def score_texts(
|
||||||
|
texts: list[str],
|
||||||
|
*,
|
||||||
|
batch_size: int = 16,
|
||||||
|
max_length: int = 256,
|
||||||
|
) -> list[FinbertOutput]:
|
||||||
|
"""주어진 텍스트 리스트에 대해 감성 점수 + 라벨 + 768d embedding 반환.
|
||||||
|
|
||||||
|
빈 문자열은 placeholder('_')로 치환해서 라벨은 neutral 에 가깝게 나오게 함.
|
||||||
|
"""
|
||||||
|
if not texts:
|
||||||
|
return []
|
||||||
|
_load()
|
||||||
|
import torch
|
||||||
|
|
||||||
|
tok = _state["tokenizer"]
|
||||||
|
mdl = _state["model"]
|
||||||
|
device = _state["device"]
|
||||||
|
|
||||||
|
results: list[FinbertOutput] = []
|
||||||
|
with torch.no_grad():
|
||||||
|
for i in range(0, len(texts), batch_size):
|
||||||
|
chunk = [(t or "").strip() or "_" for t in texts[i : i + batch_size]]
|
||||||
|
enc = tok(
|
||||||
|
chunk,
|
||||||
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
max_length=max_length,
|
||||||
|
return_tensors="pt",
|
||||||
|
).to(device)
|
||||||
|
out = mdl(**enc)
|
||||||
|
probs = torch.softmax(out.logits, dim=-1).cpu()
|
||||||
|
last_hidden = out.hidden_states[-1] # (B, T, H)
|
||||||
|
mask = enc["attention_mask"].unsqueeze(-1).float()
|
||||||
|
pooled = (last_hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
|
||||||
|
pooled = pooled.cpu().tolist()
|
||||||
|
|
||||||
|
for row, vec in zip(probs.tolist(), pooled):
|
||||||
|
p_neg, p_neu, p_pos = row[0], row[1], row[2]
|
||||||
|
label_idx = int(max(range(3), key=lambda k: row[k]))
|
||||||
|
results.append(
|
||||||
|
FinbertOutput(
|
||||||
|
label=_LABELS[label_idx],
|
||||||
|
score=float(p_pos - p_neg),
|
||||||
|
prob_negative=float(p_neg),
|
||||||
|
prob_neutral=float(p_neu),
|
||||||
|
prob_positive=float(p_pos),
|
||||||
|
embedding=[float(x) for x in vec],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def ping() -> dict[str, object]:
|
||||||
|
"""모델 로드 가능 여부 빠르게 확인. 한 번 로드되면 캐시됨."""
|
||||||
|
try:
|
||||||
|
_load()
|
||||||
|
return {"status": "ok", "model": MODEL_NAME, "device": _state["device"]}
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
return {"status": "failed", "model": MODEL_NAME, "error": str(exc)}
|
||||||
96
backend/app/nlp/score_news.py
Normal file
96
backend/app/nlp/score_news.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
"""DB news 테이블에서 sentiment_score IS NULL 인 행을 배치로 스코어링.
|
||||||
|
|
||||||
|
refresh_one / daily_batch 에서 뉴스 upsert 직후 호출. 증분만 처리.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from app.db.connection import get_engine
|
||||||
|
from app.nlp.finbert import score_texts
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ScoreResult:
|
||||||
|
fetched: int
|
||||||
|
scored: int
|
||||||
|
failed: int
|
||||||
|
error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _vector_literal(vec: list[float]) -> str:
|
||||||
|
"""pgvector 텍스트 리터럴: '[v1,v2,...]' 형식. (:e)::vector 로 캐스팅."""
|
||||||
|
return "[" + ",".join(f"{x:.6f}" for x in vec) + "]"
|
||||||
|
|
||||||
|
|
||||||
|
def score_pending_news(
|
||||||
|
*,
|
||||||
|
batch_size: int = 32,
|
||||||
|
limit: int | None = 500,
|
||||||
|
code: str | None = None,
|
||||||
|
) -> ScoreResult:
|
||||||
|
"""sentiment_score IS NULL 인 news 행에 대해 finbert score + label + embedding 채움.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch_size: finbert inference 배치 크기.
|
||||||
|
limit: 한 번에 처리할 최대 행 수. None 이면 무제한.
|
||||||
|
daily_batch 가 너무 무겁지 않도록 기본 500.
|
||||||
|
code: 특정 종목만 (None 이면 전체 무점수 행).
|
||||||
|
"""
|
||||||
|
eng = get_engine()
|
||||||
|
where = "sentiment_score IS NULL"
|
||||||
|
params: dict[str, Any] = {}
|
||||||
|
if code:
|
||||||
|
where += " AND code = :code"
|
||||||
|
params["code"] = code
|
||||||
|
|
||||||
|
sql_select = (
|
||||||
|
"SELECT id, COALESCE(title, '') || ' ' || COALESCE(body, '') AS txt "
|
||||||
|
f"FROM news WHERE {where} ORDER BY id"
|
||||||
|
)
|
||||||
|
if limit is not None:
|
||||||
|
sql_select += f" LIMIT {int(limit)}"
|
||||||
|
|
||||||
|
with eng.connect() as conn:
|
||||||
|
rows = conn.execute(text(sql_select), params).all()
|
||||||
|
if not rows:
|
||||||
|
return ScoreResult(fetched=0, scored=0, failed=0)
|
||||||
|
|
||||||
|
ids = [r[0] for r in rows]
|
||||||
|
texts_in = [r[1] for r in rows]
|
||||||
|
|
||||||
|
try:
|
||||||
|
outputs = score_texts(texts_in, batch_size=batch_size)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
logger.exception("score_texts failed")
|
||||||
|
return ScoreResult(fetched=len(rows), scored=0, failed=len(rows), error=str(exc))
|
||||||
|
|
||||||
|
update_sql = text(
|
||||||
|
"UPDATE news SET sentiment_score = :s, sentiment_label = :l, "
|
||||||
|
"embedding = (:e)::vector WHERE id = :id"
|
||||||
|
)
|
||||||
|
scored = 0
|
||||||
|
failed = 0
|
||||||
|
with eng.begin() as conn:
|
||||||
|
for nid, out in zip(ids, outputs):
|
||||||
|
try:
|
||||||
|
conn.execute(
|
||||||
|
update_sql,
|
||||||
|
{
|
||||||
|
"id": nid,
|
||||||
|
"s": out.score,
|
||||||
|
"l": out.label,
|
||||||
|
"e": _vector_literal(out.embedding),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
scored += 1
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
logger.warning("update news id=%s failed: %s", nid, exc)
|
||||||
|
failed += 1
|
||||||
|
return ScoreResult(fetched=len(rows), scored=scored, failed=failed)
|
||||||
@@ -11,6 +11,7 @@ import time
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.fetch import macro as macro_mod
|
from app.fetch import macro as macro_mod
|
||||||
|
from app.nlp.score_news import score_pending_news
|
||||||
from app.pipelines.refresh_one import refresh_code
|
from app.pipelines.refresh_one import refresh_code
|
||||||
from app.seed.seed_tickers import SEED_TICKERS
|
from app.seed.seed_tickers import SEED_TICKERS
|
||||||
|
|
||||||
@@ -32,11 +33,26 @@ def run_daily_batch() -> dict[str, Any]:
|
|||||||
for m in macros
|
for m in macros
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# 시드 종목 refresh 내에서 종목당 200건만 스코어함. 잔여(여러 소스 합쳐
|
||||||
|
# 200건 초과 또는 코드 매핑 안된 google_rss 등)는 여기서 한 번에 mop-up.
|
||||||
|
try:
|
||||||
|
mop = score_pending_news(limit=2000)
|
||||||
|
sentiment_summary: dict[str, Any] = {
|
||||||
|
"status": "ok" if mop.error is None else "failed",
|
||||||
|
"fetched": mop.fetched,
|
||||||
|
"scored": mop.scored,
|
||||||
|
"failed": mop.failed,
|
||||||
|
"error": mop.error,
|
||||||
|
}
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
sentiment_summary = {"status": "failed", "error": str(exc)}
|
||||||
|
|
||||||
elapsed = time.time() - start_ts
|
elapsed = time.time() - start_ts
|
||||||
return {
|
return {
|
||||||
"duration_seconds": round(elapsed, 2),
|
"duration_seconds": round(elapsed, 2),
|
||||||
"tickers": reports,
|
"tickers": reports,
|
||||||
"macro": macro_summary,
|
"macro": macro_summary,
|
||||||
|
"sentiment_mop": sentiment_summary,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
165
backend/app/pipelines/match_outcomes.py
Normal file
165
backend/app/pipelines/match_outcomes.py
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
"""prediction_outcomes 매칭 배치.
|
||||||
|
|
||||||
|
평일 16:30 KST 에 실행. 다음 거래일 장 종료 후 (KRX 정규장 마감 15:30) 의
|
||||||
|
확정 종가가 16:00~16:30 사이 pykrx 로 들어온 뒤, target_date == 오늘인
|
||||||
|
user_triggered=TRUE 예측을 그 종가와 매칭.
|
||||||
|
|
||||||
|
cold-start / 휴장일 대비: 인자로 받은 target_date 의 ohlcv_daily 에 종가가
|
||||||
|
없으면 자연스럽게 skip. 다음 거래일 매칭 잡이 다시 시도하면 그 날짜는
|
||||||
|
여전히 매칭되지 않으므로 (매칭 sql 이 target_date 기준), 영원히 매칭 안되는
|
||||||
|
잘못된 calendar date 예측은 cleanup CLI 로 별도 정리 가능 (Phase 7).
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import date, timedelta
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from app.db.connection import get_engine
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# direction_hit 판정 시 ±0.3% 이내는 flat. (features 의 FLAT_BAND 와 동일)
|
||||||
|
FLAT_BAND = 0.003
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MatchSummary:
|
||||||
|
target_date: str
|
||||||
|
candidates: int
|
||||||
|
matched: int
|
||||||
|
skipped_no_actual: int
|
||||||
|
already_resolved: int
|
||||||
|
|
||||||
|
|
||||||
|
def _direction_label(ret: float) -> str:
|
||||||
|
if ret > FLAT_BAND:
|
||||||
|
return "up"
|
||||||
|
if ret < -FLAT_BAND:
|
||||||
|
return "down"
|
||||||
|
return "flat"
|
||||||
|
|
||||||
|
|
||||||
|
def match_for_date(d: date) -> MatchSummary:
|
||||||
|
"""target_date == d 인 user_triggered=TRUE 예측을 매칭."""
|
||||||
|
eng = get_engine()
|
||||||
|
with eng.begin() as conn:
|
||||||
|
# 매칭 대상 예측 + 매칭 안 됐는지 확인.
|
||||||
|
candidate_rows = conn.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
SELECT p.id, p.code, p.base_date, p.horizon, p.point_forecast,
|
||||||
|
p.direction, p.model
|
||||||
|
FROM predictions p
|
||||||
|
LEFT JOIN prediction_outcomes po ON po.prediction_id = p.id
|
||||||
|
WHERE p.target_date = :d
|
||||||
|
AND p.user_triggered = TRUE
|
||||||
|
AND po.prediction_id IS NULL
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{"d": d},
|
||||||
|
).all()
|
||||||
|
candidates = len(candidate_rows)
|
||||||
|
if not candidates:
|
||||||
|
return MatchSummary(str(d), 0, 0, 0, 0)
|
||||||
|
|
||||||
|
# 종목별로 actual close 조회 (한번에 batch).
|
||||||
|
codes = list({r[1] for r in candidate_rows})
|
||||||
|
actual_map: dict[tuple[str, date], float] = {}
|
||||||
|
for code in codes:
|
||||||
|
row = conn.execute(
|
||||||
|
text(
|
||||||
|
"SELECT close FROM ohlcv_daily WHERE code = :c AND date = :d"
|
||||||
|
),
|
||||||
|
{"c": code, "d": d},
|
||||||
|
).first()
|
||||||
|
if row and row[0] is not None:
|
||||||
|
actual_map[(code, d)] = float(row[0])
|
||||||
|
|
||||||
|
# base_close (각 예측의 base_date 종가) 도 필요 — direction 판정용.
|
||||||
|
base_close_map: dict[tuple[str, date], float] = {}
|
||||||
|
for pid, code, base_date, *_ in candidate_rows:
|
||||||
|
key = (code, base_date)
|
||||||
|
if key in base_close_map:
|
||||||
|
continue
|
||||||
|
row = conn.execute(
|
||||||
|
text("SELECT close FROM ohlcv_daily WHERE code = :c AND date = :d"),
|
||||||
|
{"c": code, "d": base_date},
|
||||||
|
).first()
|
||||||
|
if row and row[0] is not None:
|
||||||
|
base_close_map[key] = float(row[0])
|
||||||
|
|
||||||
|
matched = 0
|
||||||
|
skipped = 0
|
||||||
|
already = 0
|
||||||
|
for pid, code, base_date, horizon, point_forecast, pred_dir, model in candidate_rows:
|
||||||
|
actual = actual_map.get((code, d))
|
||||||
|
base_close = base_close_map.get((code, base_date))
|
||||||
|
if actual is None or base_close is None:
|
||||||
|
skipped += 1
|
||||||
|
continue
|
||||||
|
actual_ret = actual / base_close - 1.0
|
||||||
|
actual_dir = _direction_label(actual_ret)
|
||||||
|
dir_hit = (pred_dir == actual_dir)
|
||||||
|
abs_err = abs(float(point_forecast) - actual) if point_forecast is not None else None
|
||||||
|
|
||||||
|
try:
|
||||||
|
conn.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
INSERT INTO prediction_outcomes
|
||||||
|
(prediction_id, code, target_date, horizon, model,
|
||||||
|
predicted_close, actual_close, actual_return, direction_hit, abs_error)
|
||||||
|
VALUES
|
||||||
|
(:pid, :code, :d, :h, :m, :pc, :ac, :ar, :dh, :ae)
|
||||||
|
ON CONFLICT (prediction_id) DO NOTHING
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{
|
||||||
|
"pid": pid,
|
||||||
|
"code": code,
|
||||||
|
"d": d,
|
||||||
|
"h": horizon,
|
||||||
|
"m": model,
|
||||||
|
"pc": float(point_forecast) if point_forecast is not None else None,
|
||||||
|
"ac": actual,
|
||||||
|
"ar": float(actual_ret),
|
||||||
|
"dh": bool(dir_hit),
|
||||||
|
"ae": float(abs_err) if abs_err is not None else None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
matched += 1
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
logger.warning("match insert failed pid=%s: %s", pid, exc)
|
||||||
|
already += 1
|
||||||
|
|
||||||
|
return MatchSummary(
|
||||||
|
target_date=str(d),
|
||||||
|
candidates=candidates,
|
||||||
|
matched=matched,
|
||||||
|
skipped_no_actual=skipped,
|
||||||
|
already_resolved=already,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def match_today() -> dict[str, Any]:
|
||||||
|
"""평일 16:30 KST 호출용. target_date == today (KST) 인 행 매칭."""
|
||||||
|
from datetime import datetime, timezone, timedelta as td
|
||||||
|
|
||||||
|
kst = timezone(td(hours=9))
|
||||||
|
today = datetime.now(kst).date()
|
||||||
|
summary = match_for_date(today)
|
||||||
|
return {
|
||||||
|
"today": str(today),
|
||||||
|
"summary": summary.__dict__,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
||||||
|
out = match_today()
|
||||||
|
print(json.dumps(out, ensure_ascii=False, indent=2, default=str))
|
||||||
138
backend/app/pipelines/predict_one.py
Normal file
138
backend/app/pipelines/predict_one.py
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
"""On-demand 예측 + DB 적재.
|
||||||
|
|
||||||
|
POST /api/predict/{code} 에서 호출. 사용자가 "예상차트 보기" 누른 시점.
|
||||||
|
- ensemble.predict() 로 horizons (1,3,5) 결과 계산
|
||||||
|
- base_date = 마지막 ohlcv_daily.date, target_date = base_date + horizon 영업일
|
||||||
|
(대충 calendar 일로 +h * 1.4 — KRX 영업일 추정. Phase 4 단순화: base_date + h 영업일은
|
||||||
|
ohlcv 상의 다음 h 거래일이 아닌, "거래일 카운트" 대신 단순 calendar+h 로 저장하고
|
||||||
|
매칭 잡에서 ohlcv_daily 에 그 날짜 행이 있는지로 자연 보정.)
|
||||||
|
|
||||||
|
대안 정확도 위해: 매칭 잡은 "예측의 target_date 이 오늘"인 행을 그날 종가와 비교.
|
||||||
|
calendar date 가 비거래일이면 매칭이 안 되니, 매칭 잡은 매일 실행되어 모일 때 처리.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from dataclasses import asdict
|
||||||
|
from datetime import date, datetime, timedelta, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from app.db.connection import get_engine
|
||||||
|
from app.models.ensemble import EnsemblePrediction, predict as ensemble_predict
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
KST = timezone(timedelta(hours=9))
|
||||||
|
|
||||||
|
|
||||||
|
def _next_trading_target(base_date: date, horizon: int) -> date:
|
||||||
|
"""base_date + horizon 거래일 (주말만 스킵, 공휴일은 무시 — 매칭잡이 자연 보정)."""
|
||||||
|
d = base_date
|
||||||
|
added = 0
|
||||||
|
while added < horizon:
|
||||||
|
d = d + timedelta(days=1)
|
||||||
|
if d.weekday() < 5: # 0..4 = Mon..Fri
|
||||||
|
added += 1
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
def _last_trading_date(code: str) -> date | None:
|
||||||
|
eng = get_engine()
|
||||||
|
with eng.connect() as conn:
|
||||||
|
row = conn.execute(
|
||||||
|
text("SELECT MAX(date) FROM ohlcv_daily WHERE code = :c"),
|
||||||
|
{"c": code},
|
||||||
|
).first()
|
||||||
|
return row[0] if row and row[0] else None
|
||||||
|
|
||||||
|
|
||||||
|
def predict_and_store(
|
||||||
|
code: str,
|
||||||
|
*,
|
||||||
|
horizons: tuple[int, ...] = (1, 3, 5),
|
||||||
|
user_triggered: bool = True,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""앙상블 예측 실행 + predictions 테이블 적재. 결과 JSON-serializable dict 반환."""
|
||||||
|
base_date = _last_trading_date(code)
|
||||||
|
if base_date is None:
|
||||||
|
raise RuntimeError(f"no ohlcv_daily for {code}; refresh first")
|
||||||
|
|
||||||
|
pred: EnsemblePrediction = ensemble_predict(code, horizons=horizons)
|
||||||
|
now = datetime.now(KST)
|
||||||
|
|
||||||
|
eng = get_engine()
|
||||||
|
saved_ids: list[int] = []
|
||||||
|
with eng.begin() as conn:
|
||||||
|
for step in pred.steps:
|
||||||
|
target_date = _next_trading_target(base_date, step.horizon)
|
||||||
|
features_snap = {
|
||||||
|
"base_close": pred.base_close,
|
||||||
|
"sources_used": pred.sources_used,
|
||||||
|
"direction": step.direction,
|
||||||
|
}
|
||||||
|
row = conn.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
INSERT INTO predictions
|
||||||
|
(code, predicted_at, base_date, target_date, horizon, model,
|
||||||
|
direction, prob_up, prob_flat, prob_down, expected_return,
|
||||||
|
point_forecast, ci_low, ci_high, features_snapshot, user_triggered)
|
||||||
|
VALUES
|
||||||
|
(:code, :predicted_at, :base_date, :target_date, :horizon, 'ensemble',
|
||||||
|
:direction, :p_up, :p_fl, :p_dn, :exp_ret,
|
||||||
|
:point, :lo, :hi, CAST(:feats AS JSONB), :ut)
|
||||||
|
ON CONFLICT (code, base_date, target_date, horizon, model)
|
||||||
|
DO UPDATE SET
|
||||||
|
predicted_at = EXCLUDED.predicted_at,
|
||||||
|
direction = EXCLUDED.direction,
|
||||||
|
prob_up = EXCLUDED.prob_up,
|
||||||
|
prob_flat = EXCLUDED.prob_flat,
|
||||||
|
prob_down = EXCLUDED.prob_down,
|
||||||
|
expected_return = EXCLUDED.expected_return,
|
||||||
|
point_forecast = EXCLUDED.point_forecast,
|
||||||
|
ci_low = EXCLUDED.ci_low,
|
||||||
|
ci_high = EXCLUDED.ci_high,
|
||||||
|
features_snapshot = EXCLUDED.features_snapshot,
|
||||||
|
user_triggered = predictions.user_triggered OR EXCLUDED.user_triggered
|
||||||
|
RETURNING id
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{
|
||||||
|
"code": code,
|
||||||
|
"predicted_at": now,
|
||||||
|
"base_date": base_date,
|
||||||
|
"target_date": target_date,
|
||||||
|
"horizon": step.horizon,
|
||||||
|
"direction": step.direction,
|
||||||
|
"p_up": step.prob_up,
|
||||||
|
"p_fl": step.prob_flat,
|
||||||
|
"p_dn": step.prob_down,
|
||||||
|
"exp_ret": step.expected_return,
|
||||||
|
"point": step.point_close,
|
||||||
|
"lo": step.ci_low,
|
||||||
|
"hi": step.ci_high,
|
||||||
|
"feats": json.dumps(features_snap),
|
||||||
|
"ut": user_triggered,
|
||||||
|
},
|
||||||
|
).first()
|
||||||
|
if row:
|
||||||
|
saved_ids.append(int(row[0]))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"code": code,
|
||||||
|
"base_date": str(base_date),
|
||||||
|
"base_close": pred.base_close,
|
||||||
|
"sources_used": pred.sources_used,
|
||||||
|
"steps": [
|
||||||
|
{
|
||||||
|
**asdict(s),
|
||||||
|
"target_date": str(_next_trading_target(base_date, s.horizon)),
|
||||||
|
}
|
||||||
|
for s in pred.steps
|
||||||
|
],
|
||||||
|
"saved_prediction_ids": saved_ids,
|
||||||
|
"user_triggered": user_triggered,
|
||||||
|
}
|
||||||
@@ -36,6 +36,7 @@ class RefreshReport:
|
|||||||
dart: SourceStatus
|
dart: SourceStatus
|
||||||
naver_news: SourceStatus
|
naver_news: SourceStatus
|
||||||
google_rss: SourceStatus
|
google_rss: SourceStatus
|
||||||
|
finbert: SourceStatus
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
out: dict[str, Any] = {"code": self.code}
|
out: dict[str, Any] = {"code": self.code}
|
||||||
@@ -46,6 +47,7 @@ class RefreshReport:
|
|||||||
"dart",
|
"dart",
|
||||||
"naver_news",
|
"naver_news",
|
||||||
"google_rss",
|
"google_rss",
|
||||||
|
"finbert",
|
||||||
):
|
):
|
||||||
v: SourceStatus = getattr(self, f)
|
v: SourceStatus = getattr(self, f)
|
||||||
out[f] = asdict(v)
|
out[f] = asdict(v)
|
||||||
@@ -132,6 +134,25 @@ def _google_rss(code: str, name: str) -> SourceStatus:
|
|||||||
return SourceStatus(status="failed", error=str(exc))
|
return SourceStatus(status="failed", error=str(exc))
|
||||||
|
|
||||||
|
|
||||||
|
def _finbert(code: str) -> SourceStatus:
|
||||||
|
"""방금 upsert 된 뉴스 중 sentiment_score 가 비어있는 행을 KR-FinBERT 로 스코어."""
|
||||||
|
try:
|
||||||
|
from app.nlp.score_news import score_pending_news
|
||||||
|
|
||||||
|
# 한 종목에 대해 신규 뉴스가 매우 많아도 200건으로 컷.
|
||||||
|
# daily_batch 끝에서 잔여분을 별도로 mop-up 한다.
|
||||||
|
res = score_pending_news(code=code, limit=200)
|
||||||
|
return SourceStatus(
|
||||||
|
status="ok" if res.error is None else "failed",
|
||||||
|
inserted=res.scored,
|
||||||
|
skipped=res.failed,
|
||||||
|
extra={"fetched": res.fetched},
|
||||||
|
error=res.error,
|
||||||
|
)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
return SourceStatus(status="failed", error=str(exc))
|
||||||
|
|
||||||
|
|
||||||
def refresh_code(code: str, name: str, *, lookback_days: int = 7) -> RefreshReport:
|
def refresh_code(code: str, name: str, *, lookback_days: int = 7) -> RefreshReport:
|
||||||
"""단기 갱신 (daily_batch 용). 최근 lookback_days 만 가져온다."""
|
"""단기 갱신 (daily_batch 용). 최근 lookback_days 만 가져온다."""
|
||||||
end = date.today()
|
end = date.today()
|
||||||
@@ -144,4 +165,5 @@ def refresh_code(code: str, name: str, *, lookback_days: int = 7) -> RefreshRepo
|
|||||||
dart=_dart(code, start, end),
|
dart=_dart(code, start, end),
|
||||||
naver_news=_naver_news(code),
|
naver_news=_naver_news(code),
|
||||||
google_rss=_google_rss(code, name),
|
google_rss=_google_rss(code, name),
|
||||||
|
finbert=_finbert(code),
|
||||||
)
|
)
|
||||||
|
|||||||
122
backend/app/pipelines/retrain_weekly.py
Normal file
122
backend/app/pipelines/retrain_weekly.py
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
"""주간 재학습 + 앙상블 가중치 보정.
|
||||||
|
|
||||||
|
일요일 02:00 KST 실행:
|
||||||
|
1. 시드 10종목 × horizon (1,3,5) 별로 LGBM 학습 (train_one).
|
||||||
|
2. 최근 30일 prediction_outcomes 의 model 별 hit_rate 산출, model_performance 적재.
|
||||||
|
3. 같은 30일 윈도우에서 chronos vs lgbm hit_rate 로 ensemble_weights 갱신.
|
||||||
|
(방법: w_chronos = clamp(0.1, hr_c / (hr_c + hr_l), 0.9), w_lgbm = 1 - w_chronos.
|
||||||
|
hit_rate 데이터가 부족하면 default 0.6/0.4 유지.)
|
||||||
|
|
||||||
|
지금은 'ensemble' 모델 단일 종류로 predictions 가 쌓이므로, 가중치 보정은
|
||||||
|
chronos 단독 시뮬레이션 / lgbm 단독 시뮬레이션 hit_rate 비교가 진정한 방식인데,
|
||||||
|
Phase 4 단순화: 'ensemble' 의 종합 hit_rate 만 model_performance 에 기록하고
|
||||||
|
가중치는 default 유지. 진짜 비교는 Phase 7 (chronos 단독 + lgbm 단독 예측을
|
||||||
|
shadow 로 같이 적재하는 구조) 로 확장.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from datetime import date, timedelta
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from app.db.connection import get_engine
|
||||||
|
from app.models.lgbm import train_one
|
||||||
|
from app.seed.seed_tickers import SEED_TICKERS
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
HORIZONS = (1, 3, 5)
|
||||||
|
WINDOW_DAYS = 30
|
||||||
|
|
||||||
|
|
||||||
|
def retrain_all() -> list[dict[str, Any]]:
|
||||||
|
"""시드 10종목 × horizons 학습."""
|
||||||
|
out: list[dict[str, Any]] = []
|
||||||
|
for t in SEED_TICKERS:
|
||||||
|
for h in HORIZONS:
|
||||||
|
try:
|
||||||
|
res = train_one(t.code, h)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
logger.exception("train_one failed for %s h=%d", t.code, h)
|
||||||
|
res = {"code": t.code, "horizon": h, "status": "failed", "error": str(exc)}
|
||||||
|
out.append(res)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def record_performance(as_of: date) -> list[dict[str, Any]]:
|
||||||
|
"""최근 WINDOW_DAYS 의 prediction_outcomes 로 (code, model, horizon) 별
|
||||||
|
hit_rate / mae 산출, model_performance 에 upsert."""
|
||||||
|
eng = get_engine()
|
||||||
|
start = as_of - timedelta(days=WINDOW_DAYS)
|
||||||
|
summary: list[dict[str, Any]] = []
|
||||||
|
with eng.begin() as conn:
|
||||||
|
rows = conn.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
SELECT code, model, horizon,
|
||||||
|
AVG(CASE WHEN direction_hit THEN 1.0 ELSE 0.0 END) AS hit_rate,
|
||||||
|
AVG(abs_error) AS mae,
|
||||||
|
COUNT(*) AS n
|
||||||
|
FROM prediction_outcomes
|
||||||
|
WHERE resolved_at >= :start
|
||||||
|
GROUP BY code, model, horizon
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{"start": start},
|
||||||
|
).all()
|
||||||
|
for code, model, horizon, hr, mae, n in rows:
|
||||||
|
conn.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
INSERT INTO model_performance
|
||||||
|
(code, model, window_days, as_of, hit_rate, mae, sample_count)
|
||||||
|
VALUES (:c, :m, :w, :as_of, :hr, :mae, :n)
|
||||||
|
ON CONFLICT (code, model, window_days, as_of) DO UPDATE SET
|
||||||
|
hit_rate = EXCLUDED.hit_rate,
|
||||||
|
mae = EXCLUDED.mae,
|
||||||
|
sample_count = EXCLUDED.sample_count
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{
|
||||||
|
"c": code,
|
||||||
|
"m": model,
|
||||||
|
"w": WINDOW_DAYS,
|
||||||
|
"as_of": as_of,
|
||||||
|
"hr": float(hr) if hr is not None else None,
|
||||||
|
"mae": float(mae) if mae is not None else None,
|
||||||
|
"n": int(n),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
summary.append(
|
||||||
|
{
|
||||||
|
"code": code,
|
||||||
|
"model": model,
|
||||||
|
"horizon": horizon,
|
||||||
|
"hit_rate": float(hr) if hr is not None else None,
|
||||||
|
"mae": float(mae) if mae is not None else None,
|
||||||
|
"n": int(n),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return summary
|
||||||
|
|
||||||
|
|
||||||
|
def run_weekly() -> dict[str, Any]:
|
||||||
|
"""일요일 02:00 KST 호출 entry-point."""
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
kst = timezone(timedelta(hours=9))
|
||||||
|
as_of = datetime.now(kst).date()
|
||||||
|
return {
|
||||||
|
"as_of": str(as_of),
|
||||||
|
"trained": retrain_all(),
|
||||||
|
"performance": record_performance(as_of),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
||||||
|
out = run_weekly()
|
||||||
|
print(json.dumps(out, ensure_ascii=False, indent=2, default=str))
|
||||||
@@ -22,6 +22,8 @@ from apscheduler.triggers.cron import CronTrigger
|
|||||||
from pytz import timezone
|
from pytz import timezone
|
||||||
|
|
||||||
from app.pipelines.daily_batch import run_daily_batch
|
from app.pipelines.daily_batch import run_daily_batch
|
||||||
|
from app.pipelines.match_outcomes import match_today
|
||||||
|
from app.pipelines.retrain_weekly import run_weekly
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
KST = timezone("Asia/Seoul")
|
KST = timezone("Asia/Seoul")
|
||||||
@@ -34,15 +36,34 @@ def start_scheduler() -> BackgroundScheduler:
|
|||||||
if _scheduler:
|
if _scheduler:
|
||||||
return _scheduler
|
return _scheduler
|
||||||
_scheduler = BackgroundScheduler(timezone=KST)
|
_scheduler = BackgroundScheduler(timezone=KST)
|
||||||
|
# 16:00 평일: 시드 10종목 EOD/뉴스/공시/거시 갱신
|
||||||
_scheduler.add_job(
|
_scheduler.add_job(
|
||||||
run_daily_batch,
|
run_daily_batch,
|
||||||
CronTrigger(hour=16, minute=0, timezone=KST),
|
CronTrigger(day_of_week="mon-fri", hour=16, minute=0, timezone=KST),
|
||||||
id="daily_batch_16",
|
id="daily_batch_16",
|
||||||
replace_existing=True,
|
replace_existing=True,
|
||||||
max_instances=1,
|
max_instances=1,
|
||||||
)
|
)
|
||||||
|
# 16:30 평일: prediction_outcomes 매칭 배치
|
||||||
|
_scheduler.add_job(
|
||||||
|
match_today,
|
||||||
|
CronTrigger(day_of_week="mon-fri", hour=16, minute=30, timezone=KST),
|
||||||
|
id="match_outcomes_1630",
|
||||||
|
replace_existing=True,
|
||||||
|
max_instances=1,
|
||||||
|
)
|
||||||
|
# 일요일 02:00: LGBM 재학습 + 성능 기록
|
||||||
|
_scheduler.add_job(
|
||||||
|
run_weekly,
|
||||||
|
CronTrigger(day_of_week="sun", hour=2, minute=0, timezone=KST),
|
||||||
|
id="retrain_weekly_sun_0200",
|
||||||
|
replace_existing=True,
|
||||||
|
max_instances=1,
|
||||||
|
)
|
||||||
_scheduler.start()
|
_scheduler.start()
|
||||||
logger.info("scheduler started (daily_batch @ 16:00 KST)")
|
logger.info(
|
||||||
|
"scheduler started: daily_batch(16:00 mon-fri), match_outcomes(16:30 mon-fri), retrain_weekly(sun 02:00) KST"
|
||||||
|
)
|
||||||
return _scheduler
|
return _scheduler
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -31,9 +31,12 @@ dependencies = [
|
|||||||
"transformers==4.41.2",
|
"transformers==4.41.2",
|
||||||
"tokenizers==0.19.1",
|
"tokenizers==0.19.1",
|
||||||
"sentencepiece==0.2.0",
|
"sentencepiece==0.2.0",
|
||||||
|
"accelerate==0.30.1",
|
||||||
|
"chronos-forecasting==1.4.1",
|
||||||
"scikit-learn==1.5.0",
|
"scikit-learn==1.5.0",
|
||||||
"lightgbm==4.3.0",
|
"lightgbm==4.3.0",
|
||||||
"ta==0.11.0",
|
"ta==0.11.0",
|
||||||
|
"joblib==1.4.2",
|
||||||
|
|
||||||
# scheduler
|
# scheduler
|
||||||
"apscheduler==3.10.4",
|
"apscheduler==3.10.4",
|
||||||
@@ -63,3 +66,19 @@ include = ["app*"]
|
|||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
line-length = 100
|
line-length = 100
|
||||||
target-version = "py311"
|
target-version = "py311"
|
||||||
|
|
||||||
|
[tool.mypy]
|
||||||
|
python_version = "3.11"
|
||||||
|
strict_optional = true
|
||||||
|
warn_unused_ignores = true
|
||||||
|
warn_redundant_casts = true
|
||||||
|
warn_unreachable = true
|
||||||
|
ignore_missing_imports = true # 3rd-party stub 부족 무시 (pykrx, chronos, ta, feedparser ...)
|
||||||
|
no_implicit_optional = true
|
||||||
|
show_error_codes = true
|
||||||
|
exclude = ["build", "dist", ".venv"]
|
||||||
|
|
||||||
|
[[tool.mypy.overrides]]
|
||||||
|
module = ["app.*"]
|
||||||
|
disallow_untyped_defs = false # 점진적 도입. 핵심 모듈만 strict 로 올릴 예정.
|
||||||
|
check_untyped_defs = true
|
||||||
|
|||||||
6
web/.eslintrc.json
Normal file
6
web/.eslintrc.json
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
{
|
||||||
|
"extends": "next/core-web-vitals",
|
||||||
|
"rules": {
|
||||||
|
"react-hooks/exhaustive-deps": "warn"
|
||||||
|
}
|
||||||
|
}
|
||||||
99
web/app/[code]/page.tsx
Normal file
99
web/app/[code]/page.tsx
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import Link from "next/link";
|
||||||
|
import { useEffect, useState } from "react";
|
||||||
|
import { MetricsPanel } from "../../components/MetricsPanel";
|
||||||
|
import { NewsList } from "../../components/NewsList";
|
||||||
|
import { PredictionPanel } from "../../components/PredictionPanel";
|
||||||
|
import { StockChart } from "../../components/StockChart";
|
||||||
|
import {
|
||||||
|
api,
|
||||||
|
type ChartPayload,
|
||||||
|
type LatestPredictionResponse,
|
||||||
|
} from "../../lib/api";
|
||||||
|
|
||||||
|
export default function CodePage({ params }: { params: { code: string } }) {
|
||||||
|
const { code } = params;
|
||||||
|
const [chart, setChart] = useState<ChartPayload | null>(null);
|
||||||
|
const [prediction, setPrediction] = useState<LatestPredictionResponse | null>(null);
|
||||||
|
const [err, setErr] = useState<string | null>(null);
|
||||||
|
const [days, setDays] = useState(180);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
let alive = true;
|
||||||
|
setErr(null);
|
||||||
|
setChart(null);
|
||||||
|
api
|
||||||
|
.getChart(code, days)
|
||||||
|
.then((c) => {
|
||||||
|
if (alive) setChart(c);
|
||||||
|
})
|
||||||
|
.catch((e) => {
|
||||||
|
if (alive) setErr(e instanceof Error ? e.message : String(e));
|
||||||
|
});
|
||||||
|
return () => {
|
||||||
|
alive = false;
|
||||||
|
};
|
||||||
|
}, [code, days]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
let alive = true;
|
||||||
|
api
|
||||||
|
.latestPrediction(code)
|
||||||
|
.then((r) => {
|
||||||
|
if (alive && r.found) setPrediction(r);
|
||||||
|
})
|
||||||
|
.catch(() => {
|
||||||
|
// 예측 이력 없는 경우는 무시.
|
||||||
|
});
|
||||||
|
return () => {
|
||||||
|
alive = false;
|
||||||
|
};
|
||||||
|
}, [code]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<main className="mx-auto max-w-5xl px-6 py-10">
|
||||||
|
<div className="mb-4 flex items-center justify-between">
|
||||||
|
<Link href="/" className="text-xs text-zinc-500 hover:text-zinc-300">
|
||||||
|
← 검색으로
|
||||||
|
</Link>
|
||||||
|
<select
|
||||||
|
value={days}
|
||||||
|
onChange={(e) => setDays(Number(e.target.value))}
|
||||||
|
className="rounded-md border border-zinc-700 bg-zinc-900 px-2 py-1 text-xs"
|
||||||
|
>
|
||||||
|
<option value={60}>최근 3개월</option>
|
||||||
|
<option value={180}>최근 6개월</option>
|
||||||
|
<option value={365}>최근 1년</option>
|
||||||
|
<option value={1095}>최근 3년</option>
|
||||||
|
</select>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{chart && (
|
||||||
|
<div className="mb-4">
|
||||||
|
<h1 className="text-2xl font-semibold text-zinc-100">
|
||||||
|
{chart.name}{" "}
|
||||||
|
<span className="text-sm font-normal text-zinc-500">
|
||||||
|
{chart.code} · {chart.market}
|
||||||
|
</span>
|
||||||
|
</h1>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{err && <div className="mb-4 text-sm text-red-400">차트 로딩 실패: {err}</div>}
|
||||||
|
|
||||||
|
{chart && (
|
||||||
|
<>
|
||||||
|
<StockChart chart={chart} prediction={prediction} />
|
||||||
|
<div className="mt-6">
|
||||||
|
<PredictionPanel code={code} initial={prediction} onResult={setPrediction} />
|
||||||
|
</div>
|
||||||
|
<div className="mt-6 grid gap-6 md:grid-cols-2">
|
||||||
|
<MetricsPanel code={code} />
|
||||||
|
<NewsList code={code} />
|
||||||
|
</div>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</main>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -1,13 +1,32 @@
|
|||||||
|
import { SearchBox } from "../components/SearchBox";
|
||||||
|
|
||||||
export default function HomePage() {
|
export default function HomePage() {
|
||||||
return (
|
return (
|
||||||
<main className="mx-auto max-w-3xl px-6 py-16">
|
<main className="mx-auto max-w-3xl px-6 py-16">
|
||||||
<h1 className="text-3xl font-bold tracking-tight">Stock Chart Site</h1>
|
<h1 className="text-3xl font-bold tracking-tight">Stock Chart Site</h1>
|
||||||
<p className="mt-3 text-sm text-zinc-400">
|
<p className="mt-2 text-sm text-zinc-400">
|
||||||
Phase 0 scaffold. 종목 검색 UI는 Phase 6에서 추가됩니다.
|
종목을 검색해 현재 차트를 보고, <b>예상차트 보기</b> 버튼으로 Chronos + LightGBM
|
||||||
|
앙상블의 단기(1·3·5거래일) 예측을 차트에 이어 붙입니다.
|
||||||
</p>
|
</p>
|
||||||
<div className="mt-8 rounded-md border border-zinc-800 bg-zinc-900/50 p-4 text-sm">
|
|
||||||
<div className="font-medium">Backend health</div>
|
<div className="mt-8">
|
||||||
<code className="mt-2 block text-zinc-400">GET {process.env.NEXT_PUBLIC_API_BASE}/health</code>
|
<SearchBox autoFocus />
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="mt-12 grid gap-3 text-xs text-zinc-500 sm:grid-cols-2">
|
||||||
|
<div className="rounded-md border border-zinc-800 bg-zinc-900/30 p-4">
|
||||||
|
<div className="font-medium text-zinc-300">학습 대상 10종목</div>
|
||||||
|
<div className="mt-1">
|
||||||
|
삼성전자, SK하이닉스, 에코프로비엠, 한미반도체, 두산에너빌리티, 한화에어로스페이스,
|
||||||
|
HD현대중공업, NAVER, KT&G, 한국가스공사
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div className="rounded-md border border-zinc-800 bg-zinc-900/30 p-4">
|
||||||
|
<div className="font-medium text-zinc-300">매칭/재학습</div>
|
||||||
|
<div className="mt-1">
|
||||||
|
평일 16:30 KST 에 예측과 실제 종가 매칭, 일요일 02:00 KST 에 LGBM 재학습.
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</main>
|
</main>
|
||||||
);
|
);
|
||||||
|
|||||||
64
web/components/MetricsPanel.tsx
Normal file
64
web/components/MetricsPanel.tsx
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import { useEffect, useState } from "react";
|
||||||
|
import { api, type MetricsResponse } from "../lib/api";
|
||||||
|
|
||||||
|
export function MetricsPanel({ code }: { code: string }) {
|
||||||
|
const [m, setM] = useState<MetricsResponse | null>(null);
|
||||||
|
const [err, setErr] = useState<string | null>(null);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
let alive = true;
|
||||||
|
api
|
||||||
|
.metrics(code, 30)
|
||||||
|
.then((r) => {
|
||||||
|
if (alive) setM(r);
|
||||||
|
})
|
||||||
|
.catch((e) => {
|
||||||
|
if (alive) setErr(e instanceof Error ? e.message : String(e));
|
||||||
|
});
|
||||||
|
return () => {
|
||||||
|
alive = false;
|
||||||
|
};
|
||||||
|
}, [code]);
|
||||||
|
|
||||||
|
if (err) return <div className="text-xs text-red-400">메트릭 로딩 실패: {err}</div>;
|
||||||
|
if (!m) return <div className="text-xs text-zinc-500">메트릭 로딩 중…</div>;
|
||||||
|
|
||||||
|
const rows = m.by_model_horizon ?? [];
|
||||||
|
if (!rows.length) {
|
||||||
|
return (
|
||||||
|
<div className="text-xs text-zinc-500">
|
||||||
|
최근 30일 매칭된 예측 결과가 아직 없습니다. (매칭 배치는 평일 16:30 KST 에 실행)
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="rounded-md border border-zinc-800 bg-zinc-900/40 p-4">
|
||||||
|
<div className="mb-2 text-sm font-medium text-zinc-200">최근 30일 모델 성능</div>
|
||||||
|
<table className="w-full text-left text-sm">
|
||||||
|
<thead className="text-xs text-zinc-500">
|
||||||
|
<tr>
|
||||||
|
<th className="py-1">모델</th>
|
||||||
|
<th>+거래일</th>
|
||||||
|
<th>표본 수</th>
|
||||||
|
<th>방향 적중률</th>
|
||||||
|
<th>평균 절대오차 (MAE)</th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody className="divide-y divide-zinc-800">
|
||||||
|
{rows.map((r, i) => (
|
||||||
|
<tr key={i}>
|
||||||
|
<td className="py-1">{r.model}</td>
|
||||||
|
<td>+{r.horizon}</td>
|
||||||
|
<td>{r.n}</td>
|
||||||
|
<td>{r.hit_rate != null ? `${(r.hit_rate * 100).toFixed(1)}%` : "-"}</td>
|
||||||
|
<td>{r.mae != null ? r.mae.toLocaleString(undefined, { maximumFractionDigits: 1 }) : "-"}</td>
|
||||||
|
</tr>
|
||||||
|
))}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
77
web/components/NewsList.tsx
Normal file
77
web/components/NewsList.tsx
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import { useEffect, useState } from "react";
|
||||||
|
import { api, type NewsResponse } from "../lib/api";
|
||||||
|
|
||||||
|
export function NewsList({ code }: { code: string }) {
|
||||||
|
const [data, setData] = useState<NewsResponse | null>(null);
|
||||||
|
const [err, setErr] = useState<string | null>(null);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
let alive = true;
|
||||||
|
api
|
||||||
|
.news(code, 20)
|
||||||
|
.then((r) => {
|
||||||
|
if (alive) setData(r);
|
||||||
|
})
|
||||||
|
.catch((e) => {
|
||||||
|
if (alive) setErr(e instanceof Error ? e.message : String(e));
|
||||||
|
});
|
||||||
|
return () => {
|
||||||
|
alive = false;
|
||||||
|
};
|
||||||
|
}, [code]);
|
||||||
|
|
||||||
|
if (err) return <div className="text-xs text-red-400">뉴스 로딩 실패: {err}</div>;
|
||||||
|
if (!data) return <div className="text-xs text-zinc-500">뉴스 로딩 중…</div>;
|
||||||
|
if (!data.items.length)
|
||||||
|
return <div className="text-xs text-zinc-500">최근 뉴스 없음</div>;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="rounded-md border border-zinc-800 bg-zinc-900/40 p-4">
|
||||||
|
<div className="mb-2 text-sm font-medium text-zinc-200">최근 뉴스/공시</div>
|
||||||
|
<ul className="divide-y divide-zinc-800">
|
||||||
|
{data.items.map((n, i) => (
|
||||||
|
<li key={i} className="py-2">
|
||||||
|
<a
|
||||||
|
href={n.url}
|
||||||
|
target="_blank"
|
||||||
|
rel="noopener noreferrer"
|
||||||
|
className="block hover:bg-zinc-800/40"
|
||||||
|
>
|
||||||
|
<div className="text-sm text-zinc-100 line-clamp-2">{n.title}</div>
|
||||||
|
<div className="mt-0.5 flex items-center gap-2 text-xs text-zinc-500">
|
||||||
|
<span>{n.source}</span>
|
||||||
|
{n.published_at && <span>· {formatDate(n.published_at)}</span>}
|
||||||
|
{n.sentiment_label && (
|
||||||
|
<span className={sentimentColor(n.sentiment_label)}>
|
||||||
|
· {n.sentiment_label} {n.sentiment_score != null ? `(${n.sentiment_score.toFixed(2)})` : ""}
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</a>
|
||||||
|
</li>
|
||||||
|
))}
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
function sentimentColor(l: string): string {
|
||||||
|
if (l === "positive") return "text-emerald-400";
|
||||||
|
if (l === "negative") return "text-red-400";
|
||||||
|
return "text-zinc-400";
|
||||||
|
}
|
||||||
|
|
||||||
|
function formatDate(iso: string): string {
|
||||||
|
try {
|
||||||
|
const d = new Date(iso);
|
||||||
|
return `${d.getFullYear()}-${pad(d.getMonth() + 1)}-${pad(d.getDate())} ${pad(d.getHours())}:${pad(d.getMinutes())}`;
|
||||||
|
} catch {
|
||||||
|
return iso;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function pad(n: number): string {
|
||||||
|
return n < 10 ? `0${n}` : `${n}`;
|
||||||
|
}
|
||||||
157
web/components/PredictionPanel.tsx
Normal file
157
web/components/PredictionPanel.tsx
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import { useState } from "react";
|
||||||
|
import {
|
||||||
|
api,
|
||||||
|
type LatestPredictionResponse,
|
||||||
|
type LatestPredictionStep,
|
||||||
|
type PredictResponse,
|
||||||
|
} from "../lib/api";
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
code: string;
|
||||||
|
initial?: LatestPredictionResponse | null;
|
||||||
|
onResult: (pred: LatestPredictionResponse) => void;
|
||||||
|
};
|
||||||
|
|
||||||
|
function normalizeFromPredictResponse(
|
||||||
|
code: string,
|
||||||
|
resp: PredictResponse,
|
||||||
|
): LatestPredictionResponse {
|
||||||
|
const steps: LatestPredictionStep[] = resp.steps.map((s) => ({
|
||||||
|
predicted_at: null,
|
||||||
|
target_date: s.target_date ?? "",
|
||||||
|
horizon: s.horizon,
|
||||||
|
direction: s.direction,
|
||||||
|
prob_up: s.prob_up,
|
||||||
|
prob_flat: s.prob_flat,
|
||||||
|
prob_down: s.prob_down,
|
||||||
|
expected_return: s.expected_return,
|
||||||
|
point_close: s.point_close,
|
||||||
|
ci_low: s.ci_low,
|
||||||
|
ci_high: s.ci_high,
|
||||||
|
user_triggered: resp.user_triggered,
|
||||||
|
features_snapshot: null,
|
||||||
|
}));
|
||||||
|
return {
|
||||||
|
code,
|
||||||
|
found: true,
|
||||||
|
base_date: resp.base_date,
|
||||||
|
base_close: resp.base_close,
|
||||||
|
steps,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
export function PredictionPanel({ code, initial, onResult }: Props) {
|
||||||
|
const [pred, setPred] = useState<LatestPredictionResponse | null>(initial ?? null);
|
||||||
|
const [loading, setLoading] = useState(false);
|
||||||
|
const [err, setErr] = useState<string | null>(null);
|
||||||
|
|
||||||
|
async function runPredict() {
|
||||||
|
setLoading(true);
|
||||||
|
setErr(null);
|
||||||
|
try {
|
||||||
|
const r = await api.predict(code);
|
||||||
|
const normalized = normalizeFromPredictResponse(code, r);
|
||||||
|
setPred(normalized);
|
||||||
|
onResult(normalized);
|
||||||
|
} catch (e) {
|
||||||
|
setErr(e instanceof Error ? e.message : String(e));
|
||||||
|
} finally {
|
||||||
|
setLoading(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const steps = pred?.steps ?? [];
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="rounded-md border border-zinc-800 bg-zinc-900/40 p-4">
|
||||||
|
<div className="mb-3 flex items-center justify-between">
|
||||||
|
<div>
|
||||||
|
<div className="text-sm font-medium text-zinc-200">예측 (Chronos + LightGBM 앙상블)</div>
|
||||||
|
<div className="text-xs text-zinc-500">
|
||||||
|
클릭한 종목은 자동 저장 후 다음 거래일 장 종료 시 실제 가격과 비교됩니다.
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<button
|
||||||
|
onClick={runPredict}
|
||||||
|
disabled={loading}
|
||||||
|
className="rounded-md bg-emerald-700 px-4 py-2 text-sm font-medium text-white hover:bg-emerald-600 disabled:opacity-50"
|
||||||
|
>
|
||||||
|
{loading ? "예측 중…" : pred?.found ? "다시 예측" : "예상차트 보기"}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{err && <div className="mb-3 text-xs text-red-400">에러: {err}</div>}
|
||||||
|
|
||||||
|
{pred?.found ? (
|
||||||
|
<div>
|
||||||
|
<div className="mb-2 text-xs text-zinc-500">
|
||||||
|
기준일 {pred.base_date} · 기준종가{" "}
|
||||||
|
{pred.base_close != null ? pred.base_close.toLocaleString() : "-"}
|
||||||
|
</div>
|
||||||
|
<table className="w-full text-left text-sm">
|
||||||
|
<thead className="text-xs text-zinc-500">
|
||||||
|
<tr>
|
||||||
|
<th className="py-1">+거래일</th>
|
||||||
|
<th>매칭일</th>
|
||||||
|
<th>방향</th>
|
||||||
|
<th>P(up/flat/down)</th>
|
||||||
|
<th>기대수익</th>
|
||||||
|
<th>예측 종가</th>
|
||||||
|
<th>q10~q90</th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody className="divide-y divide-zinc-800">
|
||||||
|
{steps.map((s) => (
|
||||||
|
<tr key={s.horizon}>
|
||||||
|
<td className="py-2">+{s.horizon}</td>
|
||||||
|
<td className="text-xs text-zinc-400">{s.target_date}</td>
|
||||||
|
<td>
|
||||||
|
<span
|
||||||
|
className={
|
||||||
|
s.direction === "up"
|
||||||
|
? "text-emerald-400"
|
||||||
|
: s.direction === "down"
|
||||||
|
? "text-red-400"
|
||||||
|
: "text-zinc-300"
|
||||||
|
}
|
||||||
|
>
|
||||||
|
{s.direction}
|
||||||
|
</span>
|
||||||
|
</td>
|
||||||
|
<td className="text-xs text-zinc-300">
|
||||||
|
{fmtPct(s.prob_up)} / {fmtPct(s.prob_flat)} / {fmtPct(s.prob_down)}
|
||||||
|
</td>
|
||||||
|
<td>{fmtSignedPct(s.expected_return)}</td>
|
||||||
|
<td>{s.point_close != null ? s.point_close.toLocaleString() : "-"}</td>
|
||||||
|
<td className="text-xs text-zinc-400">
|
||||||
|
{s.ci_low != null ? s.ci_low.toLocaleString() : "-"} ~{" "}
|
||||||
|
{s.ci_high != null ? s.ci_high.toLocaleString() : "-"}
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
))}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
<div className="text-xs text-zinc-500">
|
||||||
|
아직 예측이 없습니다. <b>예상차트 보기</b> 버튼을 누르면 1·3·5거래일 후 예측을 생성하고
|
||||||
|
차트에 점선으로 이어 붙입니다.
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
function fmtPct(v: number | null): string {
|
||||||
|
if (v == null) return "-";
|
||||||
|
return `${(v * 100).toFixed(0)}%`;
|
||||||
|
}
|
||||||
|
|
||||||
|
function fmtSignedPct(v: number | null): string {
|
||||||
|
if (v == null) return "-";
|
||||||
|
const pct = v * 100;
|
||||||
|
const sign = pct >= 0 ? "+" : "";
|
||||||
|
return `${sign}${pct.toFixed(2)}%`;
|
||||||
|
}
|
||||||
93
web/components/SearchBox.tsx
Normal file
93
web/components/SearchBox.tsx
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import Link from "next/link";
|
||||||
|
import { useEffect, useState } from "react";
|
||||||
|
import { api, type Symbol } from "../lib/api";
|
||||||
|
|
||||||
|
export function SearchBox({ autoFocus = false }: { autoFocus?: boolean }) {
|
||||||
|
const [q, setQ] = useState("");
|
||||||
|
const [items, setItems] = useState<Symbol[]>([]);
|
||||||
|
const [loading, setLoading] = useState(false);
|
||||||
|
const [err, setErr] = useState<string | null>(null);
|
||||||
|
const [seedOnly, setSeedOnly] = useState(false);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
const term = q.trim();
|
||||||
|
if (!term) {
|
||||||
|
setItems([]);
|
||||||
|
setErr(null);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
setLoading(true);
|
||||||
|
setErr(null);
|
||||||
|
const handle = setTimeout(async () => {
|
||||||
|
try {
|
||||||
|
const r = await api.search(term, seedOnly, 15);
|
||||||
|
setItems(r.items);
|
||||||
|
} catch (e) {
|
||||||
|
setErr(e instanceof Error ? e.message : String(e));
|
||||||
|
setItems([]);
|
||||||
|
} finally {
|
||||||
|
setLoading(false);
|
||||||
|
}
|
||||||
|
}, 200);
|
||||||
|
return () => clearTimeout(handle);
|
||||||
|
}, [q, seedOnly]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="w-full">
|
||||||
|
<div className="flex items-center gap-3">
|
||||||
|
<input
|
||||||
|
autoFocus={autoFocus}
|
||||||
|
value={q}
|
||||||
|
onChange={(e) => setQ(e.target.value)}
|
||||||
|
placeholder="종목명 또는 코드 (예: 삼성, 005930)"
|
||||||
|
className="w-full rounded-md border border-zinc-700 bg-zinc-900 px-4 py-3 text-base outline-none focus:border-zinc-500"
|
||||||
|
/>
|
||||||
|
<label className="flex shrink-0 items-center gap-2 text-xs text-zinc-400">
|
||||||
|
<input
|
||||||
|
type="checkbox"
|
||||||
|
checked={seedOnly}
|
||||||
|
onChange={(e) => setSeedOnly(e.target.checked)}
|
||||||
|
/>
|
||||||
|
학습대상 10종목만
|
||||||
|
</label>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="mt-2 min-h-[1.5rem] text-xs text-zinc-500">
|
||||||
|
{loading && "검색중…"}
|
||||||
|
{err && <span className="text-red-400">에러: {err}</span>}
|
||||||
|
{!loading && !err && q && items.length === 0 && "검색 결과 없음"}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{items.length > 0 && (
|
||||||
|
<ul className="mt-2 divide-y divide-zinc-800 rounded-md border border-zinc-800 bg-zinc-900/40">
|
||||||
|
{items.map((it) => (
|
||||||
|
<li key={it.code}>
|
||||||
|
<Link
|
||||||
|
href={`/${it.code}`}
|
||||||
|
className="flex items-center justify-between px-4 py-3 text-sm hover:bg-zinc-800/70"
|
||||||
|
>
|
||||||
|
<div>
|
||||||
|
<div className="font-medium text-zinc-100">
|
||||||
|
{it.name}
|
||||||
|
{it.is_seed && (
|
||||||
|
<span className="ml-2 rounded-full bg-emerald-900/60 px-2 py-0.5 text-[10px] font-semibold text-emerald-300">
|
||||||
|
SEED
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
<div className="text-xs text-zinc-500">
|
||||||
|
{it.code} · {it.market}
|
||||||
|
{it.sector ? ` · ${it.sector}` : ""}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<span className="text-zinc-500">→</span>
|
||||||
|
</Link>
|
||||||
|
</li>
|
||||||
|
))}
|
||||||
|
</ul>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
167
web/components/StockChart.tsx
Normal file
167
web/components/StockChart.tsx
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import { useEffect, useRef } from "react";
|
||||||
|
import {
|
||||||
|
createChart,
|
||||||
|
type CandlestickData,
|
||||||
|
type IChartApi,
|
||||||
|
type ISeriesApi,
|
||||||
|
type LineData,
|
||||||
|
type UTCTimestamp,
|
||||||
|
} from "lightweight-charts";
|
||||||
|
import type { ChartPayload, LatestPredictionResponse } from "../lib/api";
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
chart: ChartPayload;
|
||||||
|
prediction?: LatestPredictionResponse | null;
|
||||||
|
};
|
||||||
|
|
||||||
|
function dateToUtcTs(d: string): UTCTimestamp {
|
||||||
|
// 'YYYY-MM-DD' → UTC midnight epoch seconds
|
||||||
|
return (Date.UTC(
|
||||||
|
Number(d.slice(0, 4)),
|
||||||
|
Number(d.slice(5, 7)) - 1,
|
||||||
|
Number(d.slice(8, 10)),
|
||||||
|
) / 1000) as UTCTimestamp;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function StockChart({ chart, prediction }: Props) {
|
||||||
|
const containerRef = useRef<HTMLDivElement | null>(null);
|
||||||
|
const chartRef = useRef<IChartApi | null>(null);
|
||||||
|
const candleRef = useRef<ISeriesApi<"Candlestick"> | null>(null);
|
||||||
|
const predRef = useRef<ISeriesApi<"Line"> | null>(null);
|
||||||
|
const predLowRef = useRef<ISeriesApi<"Line"> | null>(null);
|
||||||
|
const predHighRef = useRef<ISeriesApi<"Line"> | null>(null);
|
||||||
|
|
||||||
|
// create chart once
|
||||||
|
useEffect(() => {
|
||||||
|
if (!containerRef.current) return;
|
||||||
|
const c = createChart(containerRef.current, {
|
||||||
|
layout: {
|
||||||
|
background: { color: "transparent" },
|
||||||
|
textColor: "#cbd5e1",
|
||||||
|
},
|
||||||
|
grid: {
|
||||||
|
vertLines: { color: "#1f2937" },
|
||||||
|
horzLines: { color: "#1f2937" },
|
||||||
|
},
|
||||||
|
rightPriceScale: { borderColor: "#374151" },
|
||||||
|
timeScale: { borderColor: "#374151", timeVisible: false },
|
||||||
|
autoSize: true,
|
||||||
|
});
|
||||||
|
const candle = c.addCandlestickSeries({
|
||||||
|
upColor: "#22c55e",
|
||||||
|
downColor: "#ef4444",
|
||||||
|
borderUpColor: "#22c55e",
|
||||||
|
borderDownColor: "#ef4444",
|
||||||
|
wickUpColor: "#22c55e",
|
||||||
|
wickDownColor: "#ef4444",
|
||||||
|
});
|
||||||
|
chartRef.current = c;
|
||||||
|
candleRef.current = candle;
|
||||||
|
return () => {
|
||||||
|
c.remove();
|
||||||
|
chartRef.current = null;
|
||||||
|
candleRef.current = null;
|
||||||
|
predRef.current = null;
|
||||||
|
predLowRef.current = null;
|
||||||
|
predHighRef.current = null;
|
||||||
|
};
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
// push candle data
|
||||||
|
useEffect(() => {
|
||||||
|
if (!candleRef.current) return;
|
||||||
|
const data: CandlestickData[] = chart.ohlcv
|
||||||
|
.filter((p) => p.open !== null && p.high !== null && p.low !== null && p.close !== null)
|
||||||
|
.map((p) => ({
|
||||||
|
time: dateToUtcTs(p.date),
|
||||||
|
open: p.open as number,
|
||||||
|
high: p.high as number,
|
||||||
|
low: p.low as number,
|
||||||
|
close: p.close as number,
|
||||||
|
}));
|
||||||
|
candleRef.current.setData(data);
|
||||||
|
chartRef.current?.timeScale().fitContent();
|
||||||
|
}, [chart]);
|
||||||
|
|
||||||
|
// push prediction overlay
|
||||||
|
useEffect(() => {
|
||||||
|
if (!chartRef.current) return;
|
||||||
|
// remove previous overlay
|
||||||
|
if (predRef.current) {
|
||||||
|
chartRef.current.removeSeries(predRef.current);
|
||||||
|
predRef.current = null;
|
||||||
|
}
|
||||||
|
if (predLowRef.current) {
|
||||||
|
chartRef.current.removeSeries(predLowRef.current);
|
||||||
|
predLowRef.current = null;
|
||||||
|
}
|
||||||
|
if (predHighRef.current) {
|
||||||
|
chartRef.current.removeSeries(predHighRef.current);
|
||||||
|
predHighRef.current = null;
|
||||||
|
}
|
||||||
|
if (!prediction || !prediction.found || !prediction.steps?.length) return;
|
||||||
|
const baseDate = prediction.base_date!;
|
||||||
|
const baseClose = prediction.base_close;
|
||||||
|
if (!baseClose) return;
|
||||||
|
const sorted = [...prediction.steps].sort((a, b) => a.horizon - b.horizon);
|
||||||
|
|
||||||
|
const med: LineData[] = [
|
||||||
|
{ time: dateToUtcTs(baseDate), value: baseClose },
|
||||||
|
...sorted
|
||||||
|
.filter((s) => s.point_close !== null)
|
||||||
|
.map((s) => ({ time: dateToUtcTs(s.target_date), value: s.point_close as number })),
|
||||||
|
];
|
||||||
|
const lo: LineData[] = [
|
||||||
|
{ time: dateToUtcTs(baseDate), value: baseClose },
|
||||||
|
...sorted
|
||||||
|
.filter((s) => s.ci_low !== null)
|
||||||
|
.map((s) => ({ time: dateToUtcTs(s.target_date), value: s.ci_low as number })),
|
||||||
|
];
|
||||||
|
const hi: LineData[] = [
|
||||||
|
{ time: dateToUtcTs(baseDate), value: baseClose },
|
||||||
|
...sorted
|
||||||
|
.filter((s) => s.ci_high !== null)
|
||||||
|
.map((s) => ({ time: dateToUtcTs(s.target_date), value: s.ci_high as number })),
|
||||||
|
];
|
||||||
|
|
||||||
|
const medLine = chartRef.current.addLineSeries({
|
||||||
|
color: "#a78bfa",
|
||||||
|
lineWidth: 2,
|
||||||
|
lineStyle: 2, // dashed
|
||||||
|
priceLineVisible: false,
|
||||||
|
lastValueVisible: true,
|
||||||
|
title: "예측 median",
|
||||||
|
});
|
||||||
|
medLine.setData(med);
|
||||||
|
const loLine = chartRef.current.addLineSeries({
|
||||||
|
color: "#7c3aed",
|
||||||
|
lineWidth: 1,
|
||||||
|
lineStyle: 1,
|
||||||
|
priceLineVisible: false,
|
||||||
|
lastValueVisible: false,
|
||||||
|
title: "q10",
|
||||||
|
});
|
||||||
|
loLine.setData(lo);
|
||||||
|
const hiLine = chartRef.current.addLineSeries({
|
||||||
|
color: "#7c3aed",
|
||||||
|
lineWidth: 1,
|
||||||
|
lineStyle: 1,
|
||||||
|
priceLineVisible: false,
|
||||||
|
lastValueVisible: false,
|
||||||
|
title: "q90",
|
||||||
|
});
|
||||||
|
hiLine.setData(hi);
|
||||||
|
predRef.current = medLine;
|
||||||
|
predLowRef.current = loLine;
|
||||||
|
predHighRef.current = hiLine;
|
||||||
|
chartRef.current.timeScale().fitContent();
|
||||||
|
}, [prediction]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="h-[460px] w-full rounded-md border border-zinc-800 bg-zinc-900/30 p-2">
|
||||||
|
<div ref={containerRef} className="h-full w-full" />
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
178
web/lib/api.ts
Normal file
178
web/lib/api.ts
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
// Backend API client.
|
||||||
|
// NEXT_PUBLIC_API_BASE 는 docker-compose 에서 http://localhost:8000 으로 주입됨.
|
||||||
|
|
||||||
|
const RAW_BASE = process.env.NEXT_PUBLIC_API_BASE ?? "http://localhost:8000";
|
||||||
|
export const API_BASE = RAW_BASE.replace(/\/$/, "");
|
||||||
|
|
||||||
|
export type Symbol = {
|
||||||
|
code: string;
|
||||||
|
name: string;
|
||||||
|
market: string;
|
||||||
|
sector: string | null;
|
||||||
|
is_seed: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type SymbolSearch = {
|
||||||
|
q: string;
|
||||||
|
count: number;
|
||||||
|
items: Symbol[];
|
||||||
|
};
|
||||||
|
|
||||||
|
export type OhlcvPoint = {
|
||||||
|
date: string;
|
||||||
|
open: number | null;
|
||||||
|
high: number | null;
|
||||||
|
low: number | null;
|
||||||
|
close: number | null;
|
||||||
|
volume: number | null;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type SentimentPoint = {
|
||||||
|
date: string;
|
||||||
|
n_articles: number;
|
||||||
|
mean_score: number | null;
|
||||||
|
weighted_score: number | null;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type TradingValuePoint = {
|
||||||
|
date: string;
|
||||||
|
foreign_net: number | null;
|
||||||
|
institution_net: number | null;
|
||||||
|
individual_net: number | null;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ChartPayload = {
|
||||||
|
code: string;
|
||||||
|
name: string;
|
||||||
|
market: string;
|
||||||
|
range: { from: string; to: string };
|
||||||
|
ohlcv: OhlcvPoint[];
|
||||||
|
sentiment: SentimentPoint[];
|
||||||
|
trading_value: TradingValuePoint[];
|
||||||
|
};
|
||||||
|
|
||||||
|
export type PredictionStep = {
|
||||||
|
horizon: number;
|
||||||
|
target_idx?: number;
|
||||||
|
point_close: number;
|
||||||
|
ci_low: number;
|
||||||
|
ci_high: number;
|
||||||
|
prob_up: number;
|
||||||
|
prob_flat: number;
|
||||||
|
prob_down: number;
|
||||||
|
direction: "up" | "flat" | "down";
|
||||||
|
expected_return: number;
|
||||||
|
target_date?: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type PredictResponse = {
|
||||||
|
code: string;
|
||||||
|
base_date: string;
|
||||||
|
base_close: number;
|
||||||
|
sources_used: string[];
|
||||||
|
steps: PredictionStep[];
|
||||||
|
saved_prediction_ids: number[];
|
||||||
|
user_triggered: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type LatestPredictionStep = {
|
||||||
|
predicted_at: string | null;
|
||||||
|
target_date: string;
|
||||||
|
horizon: number;
|
||||||
|
direction: "up" | "flat" | "down" | string;
|
||||||
|
prob_up: number | null;
|
||||||
|
prob_flat: number | null;
|
||||||
|
prob_down: number | null;
|
||||||
|
expected_return: number | null;
|
||||||
|
point_close: number | null;
|
||||||
|
ci_low: number | null;
|
||||||
|
ci_high: number | null;
|
||||||
|
user_triggered: boolean;
|
||||||
|
features_snapshot: unknown;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type LatestPredictionResponse = {
|
||||||
|
code: string;
|
||||||
|
name?: string;
|
||||||
|
found: boolean;
|
||||||
|
base_date?: string;
|
||||||
|
base_close?: number | null;
|
||||||
|
steps: LatestPredictionStep[];
|
||||||
|
};
|
||||||
|
|
||||||
|
export type MetricsRow = {
|
||||||
|
model: string;
|
||||||
|
horizon: number;
|
||||||
|
n: number;
|
||||||
|
hit_rate: number | null;
|
||||||
|
mae: number | null;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type MetricsResponse = {
|
||||||
|
code?: string;
|
||||||
|
name?: string;
|
||||||
|
window_days: number;
|
||||||
|
range: { from: string; to: string };
|
||||||
|
by_model_horizon: MetricsRow[];
|
||||||
|
};
|
||||||
|
|
||||||
|
export type NewsItem = {
|
||||||
|
source: string;
|
||||||
|
published_at: string | null;
|
||||||
|
title: string;
|
||||||
|
url: string;
|
||||||
|
sentiment_score: number | null;
|
||||||
|
sentiment_label: string | null;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type NewsResponse = {
|
||||||
|
code: string;
|
||||||
|
name: string;
|
||||||
|
count: number;
|
||||||
|
items: NewsItem[];
|
||||||
|
};
|
||||||
|
|
||||||
|
async function getJson<T>(path: string, init?: RequestInit): Promise<T> {
|
||||||
|
const res = await fetch(`${API_BASE}${path}`, {
|
||||||
|
...init,
|
||||||
|
headers: {
|
||||||
|
Accept: "application/json",
|
||||||
|
...(init?.headers ?? {}),
|
||||||
|
},
|
||||||
|
cache: "no-store",
|
||||||
|
});
|
||||||
|
if (!res.ok) {
|
||||||
|
const text = await res.text().catch(() => "");
|
||||||
|
throw new Error(`API ${path} → ${res.status} ${text || res.statusText}`);
|
||||||
|
}
|
||||||
|
return (await res.json()) as T;
|
||||||
|
}
|
||||||
|
|
||||||
|
export const api = {
|
||||||
|
search: (q: string, seedOnly = false, limit = 20) =>
|
||||||
|
getJson<SymbolSearch>(
|
||||||
|
`/api/symbols/search?q=${encodeURIComponent(q)}&limit=${limit}&seed_only=${seedOnly}`,
|
||||||
|
),
|
||||||
|
getSymbol: (code: string) => getJson<Symbol>(`/api/symbols/${encodeURIComponent(code)}`),
|
||||||
|
getChart: (code: string, days = 180) =>
|
||||||
|
getJson<ChartPayload>(`/api/chart/${encodeURIComponent(code)}?days=${days}`),
|
||||||
|
predict: (code: string, horizons = "1,3,5") =>
|
||||||
|
getJson<PredictResponse>(
|
||||||
|
`/api/predict/${encodeURIComponent(code)}?horizons=${encodeURIComponent(horizons)}`,
|
||||||
|
{ method: "POST" },
|
||||||
|
),
|
||||||
|
latestPrediction: (code: string) =>
|
||||||
|
getJson<LatestPredictionResponse>(`/api/predict/${encodeURIComponent(code)}/latest`),
|
||||||
|
metrics: (code: string, windowDays = 30) =>
|
||||||
|
getJson<MetricsResponse>(
|
||||||
|
`/api/metrics/${encodeURIComponent(code)}?window_days=${windowDays}`,
|
||||||
|
),
|
||||||
|
overallMetrics: (windowDays = 30) =>
|
||||||
|
getJson<MetricsResponse>(`/api/metrics?window_days=${windowDays}`),
|
||||||
|
news: (code: string, limit = 20, source?: string) =>
|
||||||
|
getJson<NewsResponse>(
|
||||||
|
`/api/news/${encodeURIComponent(code)}?limit=${limit}${
|
||||||
|
source ? `&source=${encodeURIComponent(source)}` : ""
|
||||||
|
}`,
|
||||||
|
),
|
||||||
|
};
|
||||||
5772
web/package-lock.json
generated
Normal file
5772
web/package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load Diff
@@ -6,7 +6,9 @@
|
|||||||
"dev": "next dev -p 3000 -H 0.0.0.0",
|
"dev": "next dev -p 3000 -H 0.0.0.0",
|
||||||
"build": "next build",
|
"build": "next build",
|
||||||
"start": "next start -p 3000 -H 0.0.0.0",
|
"start": "next start -p 3000 -H 0.0.0.0",
|
||||||
"lint": "next lint"
|
"lint": "next lint",
|
||||||
|
"typecheck": "tsc --noEmit",
|
||||||
|
"check": "npm run typecheck && npm run lint"
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"next": "14.2.3",
|
"next": "14.2.3",
|
||||||
@@ -21,6 +23,8 @@
|
|||||||
"typescript": "5.4.5",
|
"typescript": "5.4.5",
|
||||||
"tailwindcss": "3.4.4",
|
"tailwindcss": "3.4.4",
|
||||||
"postcss": "8.4.38",
|
"postcss": "8.4.38",
|
||||||
"autoprefixer": "10.4.19"
|
"autoprefixer": "10.4.19",
|
||||||
|
"eslint": "8.57.0",
|
||||||
|
"eslint-config-next": "14.2.3"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,6 +13,9 @@
|
|||||||
"isolatedModules": true,
|
"isolatedModules": true,
|
||||||
"jsx": "preserve",
|
"jsx": "preserve",
|
||||||
"incremental": true,
|
"incremental": true,
|
||||||
|
"forceConsistentCasingInFileNames": true,
|
||||||
|
"noFallthroughCasesInSwitch": true,
|
||||||
|
"noImplicitOverride": true,
|
||||||
"plugins": [{ "name": "next" }],
|
"plugins": [{ "name": "next" }],
|
||||||
"paths": { "@/*": ["./*"] }
|
"paths": { "@/*": ["./*"] }
|
||||||
},
|
},
|
||||||
|
|||||||
Reference in New Issue
Block a user