Compare commits
30 Commits
239b104a2b
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bf898d78be | ||
|
|
73593adb5c | ||
|
|
323061df02 | ||
|
|
ea885973c7 | ||
|
|
e0edc8f1e3 | ||
|
|
44873ddb39 | ||
|
|
e610599879 | ||
|
|
0a5c634680 | ||
|
|
928c2160f9 | ||
|
|
78388d347e | ||
|
|
659871118f | ||
|
|
bd47198088 | ||
|
|
fa817b31e4 | ||
|
|
96b7afd443 | ||
|
|
89651251a4 | ||
|
|
296bd6dccd | ||
|
|
2c42c1151c | ||
|
|
9c7c02703a | ||
|
|
e08f3b0765 | ||
|
|
eb56025d9c | ||
|
|
6c792305a9 | ||
|
|
5e6ce11491 | ||
|
|
0af556396e | ||
|
|
f84b460e54 | ||
|
|
bc016ab76d | ||
|
|
4fb6cec383 | ||
|
|
41ee9d5bb0 | ||
|
|
bf4fb01146 | ||
|
|
b1ca6ab5d3 | ||
|
|
edda01adbf |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -20,6 +20,7 @@ build/
|
|||||||
# Models / artifacts (downloaded HF caches, trained LGBM)
|
# Models / artifacts (downloaded HF caches, trained LGBM)
|
||||||
backend/artifacts/
|
backend/artifacts/
|
||||||
backend/.cache/
|
backend/.cache/
|
||||||
|
backend/data/
|
||||||
.huggingface/
|
.huggingface/
|
||||||
|
|
||||||
# Node
|
# Node
|
||||||
|
|||||||
42
README.md
42
README.md
@@ -129,21 +129,39 @@ stock_chart_site/
|
|||||||
|
|
||||||
## 진행 계획
|
## 진행 계획
|
||||||
|
|
||||||
- Phase 0 — 스캐폴드 (현재): Docker 환경 + DB 스키마 + 빈 FastAPI/Next.js + build.bat
|
- [x] Phase 0 — 스캐폴드: Docker 환경 + DB 스키마 + FastAPI/Next.js + build.bat
|
||||||
- Phase 1a — pykrx 데이터 파이프: 일봉/외인기관/지수 + DART + 뉴스 RSS + 거시
|
- [x] Phase 1a — pykrx 데이터 파이프: 일봉/외인기관/지수 + DART + 뉴스 RSS + 거시
|
||||||
- Phase 1b — KIS EOD (키 받으면)
|
- [x] Phase 1b — KIS read-only EOD (스모크 통과)
|
||||||
- Phase 2 — KR-FinBERT 감성 점수 + 일별 집계
|
- [x] Phase 2 — KR-FinBERT 감성 점수 + 일별 집계 뷰
|
||||||
- Phase 3 — Chronos zero-shot 예측 적재
|
- [x] Phase 3 — Chronos zero-shot 예측 어댑터 + 피처 빌더
|
||||||
- Phase 4 — LightGBM walk-forward + `prediction_outcomes` 누적 시작
|
- [x] Phase 4 — LightGBM walk-forward + ensemble + 매칭/재학습 잡
|
||||||
- Phase 5 — FastAPI 엔드포인트 (검색, 차트, on-demand 예측, 메트릭)
|
- [x] Phase 5 — FastAPI 엔드포인트 (검색/차트/예측/메트릭/뉴스)
|
||||||
- Phase 6 — Next.js UI (검색 + 현재 차트 + 예상차트 토글)
|
- [x] Phase 6 — Next.js UI (검색 + 현재 차트 + 예상차트 overlay)
|
||||||
- Phase 7 (옵션) — 백테스트 페이지 + 주간 자동 재학습
|
- [ ] Phase 7 (옵션) — 백테스트 페이지 + Chronos/LGBM 단독 shadow 예측
|
||||||
|
|
||||||
|
### API 엔드포인트 (요약)
|
||||||
|
|
||||||
|
| 메서드 | 경로 | 설명 |
|
||||||
|
|---|---|---|
|
||||||
|
| GET | `/health`, `/health/db`, `/health/keys` | 헬스/외부키 ping |
|
||||||
|
| POST | `/api/refresh/{code}?lookback_days=N` | 수동 갱신 |
|
||||||
|
| GET | `/api/symbols/search?q=&seed_only=` | 종목 검색 (trigram + prefix) |
|
||||||
|
| GET | `/api/symbols/{code}` | 종목 메타 |
|
||||||
|
| GET | `/api/chart/{code}?days=N` | OHLCV + 감성 + 외인기관거래대금 |
|
||||||
|
| POST | `/api/predict/{code}?horizons=1,3,5` | on-demand 앙상블 예측 (user_triggered) |
|
||||||
|
| GET | `/api/predict/{code}/latest` | 최신 base_date 예측 묶음 (UI overlay) |
|
||||||
|
| GET | `/api/metrics/{code}?window_days=N` | 종목 hit_rate / mae |
|
||||||
|
| GET | `/api/news/{code}?limit=N&source=` | 최근 뉴스/공시 + 감성 |
|
||||||
|
|
||||||
## 동작 모델 메모
|
## 동작 모델 메모
|
||||||
|
|
||||||
- 예측 트리거: 사용자가 "예상차트 보기" 누른 종목에 대해 즉시 inference. 결과는 `predictions(user_triggered=TRUE)` 로 저장.
|
- 예측 트리거: 사용자가 "예상차트 보기" 누른 종목에 대해 즉시 inference. 결과는 세 종류 행으로 적재:
|
||||||
- 매칭 배치: 다음 거래일 장 종료 후 (KRX 정규장 마감 15:30, 종가 확정 후 16:00 ~ 16:30 KST 사이) `user_triggered=TRUE` 인 예측 중 `target_date == 오늘 거래일`인 행들에 대해 실제 종가/방향과 매칭 → `prediction_outcomes` 적재. 주말/공휴일이면 다음 거래일로 이월.
|
- `model='ensemble'` (user_triggered=TRUE) — UI 가 표시하는 최종 예측
|
||||||
- 주간 02:00 (일요일): 종목/모델별 최근 30일 hit rate 기반으로 앙상블 가중치를 자동 보정. hit rate가 임계 미만이면 LGBM 재학습.
|
- `model='chronos'` (user_triggered=FALSE, shadow) — Chronos 단독 성능 추적용
|
||||||
|
- `model='lgbm'` (user_triggered=FALSE, shadow) — LGBM 단독 성능 추적용
|
||||||
|
- 매칭 배치: 평일 16:30 KST. `target_date <= today AND outcomes 미존재` 인 모든 행에 대해 `target_date` 이상 `today` 이하 범위의 **최초 거래일 종가**를 actual_close 로 사용 → 주말/공휴일 자동 이월. shadow 행도 함께 매칭됨.
|
||||||
|
- 주간 02:00 (일요일): 시드 10종목 × horizons LGBM 재학습. 최근 30일 prediction_outcomes 의 chronos vs lgbm hit_rate 비교 → `w_chronos = clamp(0.1, hr_c/(hr_c+hr_l), 0.9)` 공식으로 `ensemble_weights` upsert. 모델별 표본이 10 미만이면 기본값(0.6/0.4) 유지.
|
||||||
|
- DB bootstrap: 백엔드 첫 부팅 시 lifespan 에서 idempotent migration + symbols 시드(비어있을 때만 pykrx 전 종목 적재) 자동 수행. `BOOTSTRAP_DISABLED=1` 로 비활성화 가능.
|
||||||
|
|
||||||
## 안전/한계
|
## 안전/한계
|
||||||
|
|
||||||
|
|||||||
@@ -7,27 +7,55 @@ ENV DEBIAN_FRONTEND=noninteractive \
|
|||||||
PYTHONDONTWRITEBYTECODE=1 \
|
PYTHONDONTWRITEBYTECODE=1 \
|
||||||
PIP_NO_CACHE_DIR=1 \
|
PIP_NO_CACHE_DIR=1 \
|
||||||
PIP_DISABLE_PIP_VERSION_CHECK=1 \
|
PIP_DISABLE_PIP_VERSION_CHECK=1 \
|
||||||
|
PYTHONPATH=/app \
|
||||||
TZ=Asia/Seoul
|
TZ=Asia/Seoul
|
||||||
|
|
||||||
|
# Ubuntu 22.04 의 python3-pip 는 python3.10 을 가리키므로 설치하지 않고,
|
||||||
|
# python3.11 + get-pip.py 로 3.11 전용 pip 를 부트스트랩한다.
|
||||||
|
# (Debian/Ubuntu 의 시스템 python 은 ensurepip 가 막혀 있어 get-pip.py 가 가장 깔끔함.)
|
||||||
|
# 이후 모든 호출은 `python -m pip` 로 통일해 인터프리터/스크립트 불일치를 차단.
|
||||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
python3.11 python3.11-venv python3-pip \
|
python3.11 python3.11-venv \
|
||||||
build-essential git curl ca-certificates tzdata \
|
build-essential git curl ca-certificates tzdata \
|
||||||
libgomp1 \
|
libgomp1 \
|
||||||
&& ln -sf /usr/bin/python3.11 /usr/local/bin/python \
|
&& ln -sf /usr/bin/python3.11 /usr/local/bin/python \
|
||||||
&& ln -sf /usr/bin/python3.11 /usr/local/bin/python3 \
|
&& ln -sf /usr/bin/python3.11 /usr/local/bin/python3 \
|
||||||
|
&& curl -sSL https://bootstrap.pypa.io/get-pip.py -o /tmp/get-pip.py \
|
||||||
|
&& python /tmp/get-pip.py \
|
||||||
|
&& rm /tmp/get-pip.py \
|
||||||
|
&& python -m pip install --upgrade pip wheel \
|
||||||
|
&& python -m pip install "setuptools<80" \
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
# setuptools 80+ 은 pkg_resources 모듈을 제거함. pykrx 가 `import pkg_resources` 를
|
||||||
|
# 하므로 80 미만으로 핀. 아래 reqs.txt 단계에서 다른 deps 가 setuptools 재upgrade 를
|
||||||
|
# 트리거하지 않도록 별도 명령으로 고정.
|
||||||
|
|
||||||
|
# Sanity check: 이 출력은 빌드 로그에 박혀서 다음에 인터프리터 불일치 의심될 때 즉시 확인 가능.
|
||||||
|
RUN python -V && python -m pip -V
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
COPY pyproject.toml ./
|
COPY pyproject.toml ./
|
||||||
|
|
||||||
# Install PyTorch (CUDA 12.1 wheels) first so the rest of deps don't downgrade it.
|
# Install PyTorch (CUDA 12.1 wheels) first so the rest of deps don't downgrade it.
|
||||||
RUN pip install --extra-index-url https://download.pytorch.org/whl/cu121 \
|
RUN python -m pip install --extra-index-url https://download.pytorch.org/whl/cu121 \
|
||||||
torch==2.3.1 torchvision==0.18.1
|
torch==2.3.1 torchvision==0.18.1
|
||||||
RUN pip install --no-deps -e . || true
|
|
||||||
RUN pip install -e .
|
# Install runtime deps from pyproject.toml WITHOUT installing the project itself.
|
||||||
|
# - 이전 `pip install -e .` 은 app/ 가 아직 COPY 되기 전이라 packages.find 결과가 비고,
|
||||||
|
# ubuntu 22.04 기본 pip 의 PEP 660 editable hook 과 충돌해 실패했음.
|
||||||
|
# - 런타임에는 PYTHONPATH=/app 으로 `app.*` 임포트가 동작하므로 프로젝트 설치 자체가 불필요.
|
||||||
|
# - deps 만 별도 레이어로 캐시 → 코드 변경 시 ML 휠 재빌드 회피.
|
||||||
|
RUN python -c "import tomllib; \
|
||||||
|
deps = tomllib.load(open('pyproject.toml','rb'))['project']['dependencies']; \
|
||||||
|
open('/tmp/reqs.txt','w').write('\n'.join(deps))" \
|
||||||
|
&& python -m pip install -r /tmp/reqs.txt \
|
||||||
|
&& python -m pip install "setuptools<80" \
|
||||||
|
&& python -c "import pkg_resources; print('pkg_resources OK from', pkg_resources.__file__)"
|
||||||
|
|
||||||
COPY app ./app
|
COPY app ./app
|
||||||
|
|
||||||
EXPOSE 8000
|
EXPOSE 8000
|
||||||
|
|
||||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"]
|
# uvicorn 콘솔 스크립트 대신 `python -m uvicorn` 으로 호출 — 3.11 인터프리터에서 실행됨을 보장.
|
||||||
|
CMD ["python", "-m", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"]
|
||||||
|
|||||||
340
backend/app/api/chart.py
Normal file
340
backend/app/api/chart.py
Normal file
@@ -0,0 +1,340 @@
|
|||||||
|
"""차트 데이터 API: OHLCV + 보조 데이터 (감성, 거시).
|
||||||
|
|
||||||
|
UI: /code 페이지가 호출 → lightweight-charts 캔들 데이터로 사용.
|
||||||
|
|
||||||
|
interval 파라미터로 캔들 단위 선택:
|
||||||
|
- "10m" : 당일 10분봉. ohlcv_1m 을 time_bucket 으로 10분 단위 집계.
|
||||||
|
stale (>10분) 이면 KIS inquire-time-itemchartprice 로 즉시 보충.
|
||||||
|
- "1d" : 일봉. ohlcv_daily 직접 조회. 비어있으면 pykrx auto-refresh.
|
||||||
|
- "1w" : 주봉. ohlcv_daily 를 date_trunc('week') 로 집계.
|
||||||
|
- "1mo" : 월봉. ohlcv_daily 를 date_trunc('month') 로 집계.
|
||||||
|
|
||||||
|
10m 외에는 date 필드가 'YYYY-MM-DD' ISO date 문자열,
|
||||||
|
10m 일 때는 'YYYY-MM-DDTHH:MM:SS' ISO datetime (KST) 으로 통일.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import date, datetime, time as dtime, timedelta, timezone
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Query
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from app.db.connection import get_engine
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/chart", tags=["chart"])
|
||||||
|
|
||||||
|
ALLOWED_INTERVALS = ("10m", "1d", "1w", "1mo")
|
||||||
|
KST = timezone(timedelta(hours=9))
|
||||||
|
|
||||||
|
|
||||||
|
def _query_ohlcv_daily(conn, code: str, start: date, end: date):
|
||||||
|
return conn.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
SELECT date, open, high, low, close, volume
|
||||||
|
FROM ohlcv_daily
|
||||||
|
WHERE code = :c AND date BETWEEN :s AND :e
|
||||||
|
ORDER BY date
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{"c": code, "s": start, "e": end},
|
||||||
|
).all()
|
||||||
|
|
||||||
|
|
||||||
|
def _query_ohlcv_bucketed(conn, code: str, start: date, end: date, trunc: str):
|
||||||
|
"""1d → 1w/1mo 집계. date_trunc 로 bucket 잡고, 첫/마지막/최고/최저/합 집계.
|
||||||
|
|
||||||
|
open=bucket 첫 거래일 시가, close=마지막 거래일 종가. PostgreSQL window 함수로 구한다.
|
||||||
|
"""
|
||||||
|
return conn.execute(
|
||||||
|
text(
|
||||||
|
f"""
|
||||||
|
WITH base AS (
|
||||||
|
SELECT date_trunc(:trunc, date)::date AS bucket,
|
||||||
|
date, open, high, low, close, volume
|
||||||
|
FROM ohlcv_daily
|
||||||
|
WHERE code = :c AND date BETWEEN :s AND :e
|
||||||
|
),
|
||||||
|
ranked AS (
|
||||||
|
SELECT bucket, date, open, high, low, close, volume,
|
||||||
|
ROW_NUMBER() OVER (PARTITION BY bucket ORDER BY date ASC) AS rn_first,
|
||||||
|
ROW_NUMBER() OVER (PARTITION BY bucket ORDER BY date DESC) AS rn_last
|
||||||
|
FROM base
|
||||||
|
)
|
||||||
|
SELECT bucket AS date,
|
||||||
|
MAX(open) FILTER (WHERE rn_first = 1) AS open,
|
||||||
|
MAX(high) AS high,
|
||||||
|
MIN(low) AS low,
|
||||||
|
MAX(close) FILTER (WHERE rn_last = 1) AS close,
|
||||||
|
SUM(volume) AS volume
|
||||||
|
FROM ranked
|
||||||
|
GROUP BY bucket
|
||||||
|
ORDER BY bucket
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{"c": code, "s": start, "e": end, "trunc": trunc},
|
||||||
|
).all()
|
||||||
|
|
||||||
|
|
||||||
|
def _query_ohlcv_10m(conn, code: str, start_ts: datetime, end_ts: datetime):
|
||||||
|
"""ohlcv_1m → 10분봉. TimescaleDB time_bucket 으로 10분 단위 집계.
|
||||||
|
|
||||||
|
first()/last() 는 TimescaleDB 의 집계함수.
|
||||||
|
"""
|
||||||
|
return conn.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
SELECT time_bucket(INTERVAL '10 minutes', ts) AS bucket,
|
||||||
|
first(open, ts) AS open,
|
||||||
|
MAX(high) AS high,
|
||||||
|
MIN(low) AS low,
|
||||||
|
last(close, ts) AS close,
|
||||||
|
SUM(volume) AS volume
|
||||||
|
FROM ohlcv_1m
|
||||||
|
WHERE code = :c AND ts >= :s AND ts < :e
|
||||||
|
GROUP BY bucket
|
||||||
|
ORDER BY bucket
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{"c": code, "s": start_ts, "e": end_ts},
|
||||||
|
).all()
|
||||||
|
|
||||||
|
|
||||||
|
def _upsert_ohlcv_1m(conn, code: str, rows: list[dict]) -> int:
|
||||||
|
"""KIS 분봉 응답을 ohlcv_1m 에 UPSERT. 같은 (code, ts) 는 덮어쓰기 (장중 갱신용)."""
|
||||||
|
if not rows:
|
||||||
|
return 0
|
||||||
|
conn.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
INSERT INTO ohlcv_1m (code, ts, open, high, low, close, volume)
|
||||||
|
VALUES (:code, :ts, :open, :high, :low, :close, :volume)
|
||||||
|
ON CONFLICT (code, ts) DO UPDATE SET
|
||||||
|
open = EXCLUDED.open,
|
||||||
|
high = EXCLUDED.high,
|
||||||
|
low = EXCLUDED.low,
|
||||||
|
close = EXCLUDED.close,
|
||||||
|
volume = EXCLUDED.volume
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
[{"code": code, **r} for r in rows],
|
||||||
|
)
|
||||||
|
return len(rows)
|
||||||
|
|
||||||
|
|
||||||
|
def _intraday_window_today() -> tuple[datetime, datetime]:
|
||||||
|
"""오늘 KST 의 장 시간대 윈도우 (08:50 ~ 16:00). 토/일은 직전 영업일."""
|
||||||
|
now = datetime.now(KST)
|
||||||
|
d = now.date()
|
||||||
|
# 주말이면 직전 금요일로
|
||||||
|
while d.weekday() >= 5:
|
||||||
|
d -= timedelta(days=1)
|
||||||
|
start = datetime.combine(d, dtime(8, 50), tzinfo=KST)
|
||||||
|
end = datetime.combine(d, dtime(16, 0), tzinfo=KST)
|
||||||
|
return start, end
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_intraday_fresh(conn, code: str) -> str:
|
||||||
|
"""오늘 윈도우의 ohlcv_1m 을 필요한 만큼만 KIS 에서 보충.
|
||||||
|
|
||||||
|
분기:
|
||||||
|
- 주말: KIS 분봉 endpoint 는 "당일" 만 지원 → 호출하지 않음. 'weekend'.
|
||||||
|
- 장외 (평일 09:00 이전 또는 15:35 이후) + 이미 오늘 데이터 있음: 'cached_closed'.
|
||||||
|
(마감 후엔 데이터 늘지 않으므로 KIS 호출 의미 없음)
|
||||||
|
- 장중 + last_ts 가 10분 이내: 'fresh' (DB 만 읽음)
|
||||||
|
- 그 외 (장중 stale / 장 막 끝나서 마지막 마감 데이터 1회 필요 / 캐시 비어있음):
|
||||||
|
last_ts+1m ~ now 사이의 빈 구간을 fetch_minute_range 로 페이지네이션 채움.
|
||||||
|
DB 캐시가 그날 데이터를 이미 갖고 있으면 자연히 호출 1~2 페이지로 끝.
|
||||||
|
|
||||||
|
Returns: 'fresh' | 'refreshed' | 'cached_closed' | 'weekend' |
|
||||||
|
'skipped_missing_key' | 'failed' | 'no_data'
|
||||||
|
"""
|
||||||
|
now = datetime.now(KST)
|
||||||
|
if now.weekday() >= 5:
|
||||||
|
return "weekend"
|
||||||
|
|
||||||
|
win_start, win_end = _intraday_window_today()
|
||||||
|
last_ts = conn.execute(
|
||||||
|
text(
|
||||||
|
"SELECT MAX(ts) FROM ohlcv_1m WHERE code = :c AND ts >= :s AND ts < :e"
|
||||||
|
),
|
||||||
|
{"c": code, "s": win_start, "e": win_end},
|
||||||
|
).scalar()
|
||||||
|
|
||||||
|
market_open = dtime(9, 0)
|
||||||
|
market_close_buffer = dtime(15, 35)
|
||||||
|
in_session = market_open <= now.time() <= market_close_buffer
|
||||||
|
|
||||||
|
# 장외이고 이미 오늘 데이터 있음 → 추가 호출 불필요
|
||||||
|
if not in_session and last_ts is not None:
|
||||||
|
return "cached_closed"
|
||||||
|
|
||||||
|
# 장중 + 10분 이내 갱신 → 추가 호출 불필요
|
||||||
|
if in_session and last_ts is not None and (now - last_ts) < timedelta(minutes=10):
|
||||||
|
return "fresh"
|
||||||
|
|
||||||
|
# fetch 윈도우 = [last_ts+1m or win_start, min(now, win_end)]
|
||||||
|
fetch_to = min(now, win_end)
|
||||||
|
if last_ts is not None and last_ts >= win_start:
|
||||||
|
fetch_from = last_ts + timedelta(minutes=1)
|
||||||
|
else:
|
||||||
|
fetch_from = win_start
|
||||||
|
if fetch_from >= fetch_to:
|
||||||
|
return "fresh"
|
||||||
|
|
||||||
|
try:
|
||||||
|
from app.fetch.kis import SkippedMissingKey, fetch_minute_range
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
return "failed"
|
||||||
|
|
||||||
|
try:
|
||||||
|
rows = fetch_minute_range(code, fetch_from, fetch_to)
|
||||||
|
except SkippedMissingKey:
|
||||||
|
return "skipped_missing_key"
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
logger.exception("intraday refresh failed for %s", code)
|
||||||
|
return "failed"
|
||||||
|
|
||||||
|
if not rows:
|
||||||
|
return "no_data"
|
||||||
|
_upsert_ohlcv_1m(conn, code, rows)
|
||||||
|
conn.commit()
|
||||||
|
return "refreshed"
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{code}")
|
||||||
|
def get_chart(
|
||||||
|
code: str,
|
||||||
|
days: int = Query(default=180, ge=1, le=3650),
|
||||||
|
interval: str = Query(default="1d"),
|
||||||
|
include_sentiment: bool = Query(default=True),
|
||||||
|
include_trading_value: bool = Query(default=True),
|
||||||
|
) -> dict:
|
||||||
|
if interval not in ALLOWED_INTERVALS:
|
||||||
|
raise HTTPException(status_code=400, detail=f"interval must be one of {ALLOWED_INTERVALS}")
|
||||||
|
|
||||||
|
eng = get_engine()
|
||||||
|
end = date.today()
|
||||||
|
start = end - timedelta(days=days)
|
||||||
|
with eng.connect() as conn:
|
||||||
|
symbol = conn.execute(
|
||||||
|
text("SELECT code, name, market FROM symbols WHERE code = :c"),
|
||||||
|
{"c": code},
|
||||||
|
).first()
|
||||||
|
if not symbol:
|
||||||
|
raise HTTPException(status_code=404, detail=f"unknown code: {code}")
|
||||||
|
|
||||||
|
ohlcv: list[dict] = []
|
||||||
|
intraday_status: str | None = None
|
||||||
|
|
||||||
|
if interval == "10m":
|
||||||
|
intraday_status = _ensure_intraday_fresh(conn, code)
|
||||||
|
win_start, win_end = _intraday_window_today()
|
||||||
|
rows = _query_ohlcv_10m(conn, code, win_start, win_end)
|
||||||
|
ohlcv = [
|
||||||
|
{
|
||||||
|
# KST aware datetime → ISO datetime. 프론트에서 Date 파싱.
|
||||||
|
"date": (r[0].astimezone(KST) if r[0].tzinfo else r[0].replace(tzinfo=KST))
|
||||||
|
.strftime("%Y-%m-%dT%H:%M:%S"),
|
||||||
|
"open": float(r[1]) if r[1] is not None else None,
|
||||||
|
"high": float(r[2]) if r[2] is not None else None,
|
||||||
|
"low": float(r[3]) if r[3] is not None else None,
|
||||||
|
"close": float(r[4]) if r[4] is not None else None,
|
||||||
|
"volume": int(r[5]) if r[5] is not None else None,
|
||||||
|
}
|
||||||
|
for r in rows
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
if interval == "1d":
|
||||||
|
rows = _query_ohlcv_daily(conn, code, start, end)
|
||||||
|
elif interval == "1w":
|
||||||
|
rows = _query_ohlcv_bucketed(conn, code, start, end, "week")
|
||||||
|
else: # "1mo"
|
||||||
|
rows = _query_ohlcv_bucketed(conn, code, start, end, "month")
|
||||||
|
|
||||||
|
if not rows and interval == "1d":
|
||||||
|
# 첫 방문 → pykrx auto-refresh.
|
||||||
|
try:
|
||||||
|
from app.pipelines.refresh_one import refresh_code
|
||||||
|
logger.info("chart: ohlcv_daily empty for %s — auto-refresh", code)
|
||||||
|
refresh_code(symbol[0], symbol[1], lookback_days=max(days, 365))
|
||||||
|
rows = _query_ohlcv_daily(conn, code, start, end)
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
logger.exception("chart: auto-refresh failed for %s", code)
|
||||||
|
|
||||||
|
ohlcv = [
|
||||||
|
{
|
||||||
|
"date": str(r[0]),
|
||||||
|
"open": float(r[1]) if r[1] is not None else None,
|
||||||
|
"high": float(r[2]) if r[2] is not None else None,
|
||||||
|
"low": float(r[3]) if r[3] is not None else None,
|
||||||
|
"close": float(r[4]) if r[4] is not None else None,
|
||||||
|
"volume": int(r[5]) if r[5] is not None else None,
|
||||||
|
}
|
||||||
|
for r in rows
|
||||||
|
]
|
||||||
|
|
||||||
|
sentiment: list[dict] = []
|
||||||
|
if include_sentiment and interval != "10m":
|
||||||
|
try:
|
||||||
|
s_rows = conn.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
SELECT date, n_articles, mean_score, weighted_score
|
||||||
|
FROM v_sentiment_daily
|
||||||
|
WHERE code = :c AND date BETWEEN :s AND :e
|
||||||
|
ORDER BY date
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{"c": code, "s": start, "e": end},
|
||||||
|
).all()
|
||||||
|
sentiment = [
|
||||||
|
{
|
||||||
|
"date": str(r[0]),
|
||||||
|
"n_articles": int(r[1]) if r[1] is not None else 0,
|
||||||
|
"mean_score": float(r[2]) if r[2] is not None else None,
|
||||||
|
"weighted_score": float(r[3]) if r[3] is not None else None,
|
||||||
|
}
|
||||||
|
for r in s_rows
|
||||||
|
]
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
sentiment = []
|
||||||
|
|
||||||
|
trading: list[dict] = []
|
||||||
|
if include_trading_value and interval != "10m":
|
||||||
|
tv_rows = conn.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
SELECT date, foreign_net, institution_net, individual_net
|
||||||
|
FROM trading_value_daily
|
||||||
|
WHERE code = :c AND date BETWEEN :s AND :e
|
||||||
|
ORDER BY date
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{"c": code, "s": start, "e": end},
|
||||||
|
).all()
|
||||||
|
trading = [
|
||||||
|
{
|
||||||
|
"date": str(r[0]),
|
||||||
|
"foreign_net": float(r[1]) if r[1] is not None else None,
|
||||||
|
"institution_net": float(r[2]) if r[2] is not None else None,
|
||||||
|
"individual_net": float(r[3]) if r[3] is not None else None,
|
||||||
|
}
|
||||||
|
for r in tv_rows
|
||||||
|
]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"code": symbol[0],
|
||||||
|
"name": symbol[1],
|
||||||
|
"market": symbol[2],
|
||||||
|
"interval": interval,
|
||||||
|
"intraday_status": intraday_status,
|
||||||
|
"range": {"from": str(start), "to": str(end)},
|
||||||
|
"today": date.today().isoformat(),
|
||||||
|
"ohlcv": ohlcv,
|
||||||
|
"sentiment": sentiment,
|
||||||
|
"trading_value": trading,
|
||||||
|
}
|
||||||
101
backend/app/api/metrics.py
Normal file
101
backend/app/api/metrics.py
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
"""모델 성능 메트릭 API."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import date, timedelta
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Query
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from app.db.connection import get_engine
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/metrics", tags=["metrics"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{code}")
|
||||||
|
def code_metrics(
|
||||||
|
code: str,
|
||||||
|
window_days: int = Query(default=30, ge=1, le=365),
|
||||||
|
) -> dict:
|
||||||
|
"""code 의 최근 window_days 윈도우 prediction_outcomes 집계."""
|
||||||
|
eng = get_engine()
|
||||||
|
end = date.today()
|
||||||
|
start = end - timedelta(days=window_days)
|
||||||
|
with eng.connect() as conn:
|
||||||
|
sym = conn.execute(
|
||||||
|
text("SELECT code, name FROM symbols WHERE code = :c"),
|
||||||
|
{"c": code},
|
||||||
|
).first()
|
||||||
|
if not sym:
|
||||||
|
raise HTTPException(status_code=404, detail=f"unknown code: {code}")
|
||||||
|
rows = conn.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
SELECT model, horizon,
|
||||||
|
COUNT(*) AS n,
|
||||||
|
AVG(CASE WHEN direction_hit THEN 1.0 ELSE 0.0 END) AS hit_rate,
|
||||||
|
AVG(abs_error) AS mae
|
||||||
|
FROM prediction_outcomes
|
||||||
|
WHERE code = :c AND resolved_at >= :s
|
||||||
|
GROUP BY model, horizon
|
||||||
|
ORDER BY model, horizon
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{"c": code, "s": start},
|
||||||
|
).all()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"code": sym[0],
|
||||||
|
"name": sym[1],
|
||||||
|
"window_days": window_days,
|
||||||
|
"range": {"from": str(start), "to": str(end)},
|
||||||
|
"by_model_horizon": [
|
||||||
|
{
|
||||||
|
"model": r[0],
|
||||||
|
"horizon": int(r[1]),
|
||||||
|
"n": int(r[2]),
|
||||||
|
"hit_rate": float(r[3]) if r[3] is not None else None,
|
||||||
|
"mae": float(r[4]) if r[4] is not None else None,
|
||||||
|
}
|
||||||
|
for r in rows
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("")
|
||||||
|
def overall_metrics(
|
||||||
|
window_days: int = Query(default=30, ge=1, le=365),
|
||||||
|
) -> dict:
|
||||||
|
"""전체 시드 종목 누적 메트릭."""
|
||||||
|
eng = get_engine()
|
||||||
|
end = date.today()
|
||||||
|
start = end - timedelta(days=window_days)
|
||||||
|
with eng.connect() as conn:
|
||||||
|
rows = conn.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
SELECT po.model, po.horizon,
|
||||||
|
COUNT(*) AS n,
|
||||||
|
AVG(CASE WHEN po.direction_hit THEN 1.0 ELSE 0.0 END) AS hit_rate,
|
||||||
|
AVG(po.abs_error) AS mae
|
||||||
|
FROM prediction_outcomes po
|
||||||
|
WHERE po.resolved_at >= :s
|
||||||
|
GROUP BY po.model, po.horizon
|
||||||
|
ORDER BY po.model, po.horizon
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{"s": start},
|
||||||
|
).all()
|
||||||
|
return {
|
||||||
|
"window_days": window_days,
|
||||||
|
"range": {"from": str(start), "to": str(end)},
|
||||||
|
"by_model_horizon": [
|
||||||
|
{
|
||||||
|
"model": r[0],
|
||||||
|
"horizon": int(r[1]),
|
||||||
|
"n": int(r[2]),
|
||||||
|
"hit_rate": float(r[3]) if r[3] is not None else None,
|
||||||
|
"mae": float(r[4]) if r[4] is not None else None,
|
||||||
|
}
|
||||||
|
for r in rows
|
||||||
|
],
|
||||||
|
}
|
||||||
58
backend/app/api/news.py
Normal file
58
backend/app/api/news.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
"""뉴스/공시 목록 API."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Query
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from app.db.connection import get_engine
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/news", tags=["news"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{code}")
|
||||||
|
def list_news(
|
||||||
|
code: str,
|
||||||
|
limit: int = Query(default=20, ge=1, le=200),
|
||||||
|
source: str | None = Query(default=None, description="naver_finance / google_rss / dart"),
|
||||||
|
) -> dict:
|
||||||
|
eng = get_engine()
|
||||||
|
with eng.connect() as conn:
|
||||||
|
sym = conn.execute(
|
||||||
|
text("SELECT code, name FROM symbols WHERE code = :c"),
|
||||||
|
{"c": code},
|
||||||
|
).first()
|
||||||
|
if not sym:
|
||||||
|
raise HTTPException(status_code=404, detail=f"unknown code: {code}")
|
||||||
|
where = "code = :c"
|
||||||
|
params: dict = {"c": code, "lim": limit}
|
||||||
|
if source:
|
||||||
|
where += " AND source = :src"
|
||||||
|
params["src"] = source
|
||||||
|
rows = conn.execute(
|
||||||
|
text(
|
||||||
|
f"""
|
||||||
|
SELECT source, published_at, title, url, sentiment_score, sentiment_label
|
||||||
|
FROM news
|
||||||
|
WHERE {where}
|
||||||
|
ORDER BY published_at DESC
|
||||||
|
LIMIT :lim
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
params,
|
||||||
|
).all()
|
||||||
|
return {
|
||||||
|
"code": sym[0],
|
||||||
|
"name": sym[1],
|
||||||
|
"count": len(rows),
|
||||||
|
"items": [
|
||||||
|
{
|
||||||
|
"source": r[0],
|
||||||
|
"published_at": r[1].isoformat() if r[1] else None,
|
||||||
|
"title": r[2],
|
||||||
|
"url": r[3],
|
||||||
|
"sentiment_score": float(r[4]) if r[4] is not None else None,
|
||||||
|
"sentiment_label": r[5],
|
||||||
|
}
|
||||||
|
for r in rows
|
||||||
|
],
|
||||||
|
}
|
||||||
137
backend/app/api/predict.py
Normal file
137
backend/app/api/predict.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
"""예측 API.
|
||||||
|
|
||||||
|
POST /api/predict/{code} : on-demand 예측 + predictions 적재 (user_triggered=TRUE)
|
||||||
|
GET /api/predict/{code}/latest : 가장 최근 base_date 의 예측 horizons 묶음 반환
|
||||||
|
(UI 가 새로고침해도 같은 결과 보여줄 때 사용)
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Query
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from app.db.connection import get_engine
|
||||||
|
from app.pipelines.predict_one import predict_and_store
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
KST = timezone(timedelta(hours=9))
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/predict", tags=["predict"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{code}")
|
||||||
|
def predict_endpoint(
|
||||||
|
code: str,
|
||||||
|
horizons: str = Query(default="1,3,5", description="쉼표 구분 정수. ex) '1,3,5'"),
|
||||||
|
user_triggered: bool = Query(default=True),
|
||||||
|
) -> dict:
|
||||||
|
"""on-demand 예측. UI 의 '예상차트 보기' 클릭이 호출."""
|
||||||
|
eng = get_engine()
|
||||||
|
with eng.connect() as conn:
|
||||||
|
row = conn.execute(
|
||||||
|
text("SELECT code, name FROM symbols WHERE code = :c"),
|
||||||
|
{"c": code},
|
||||||
|
).first()
|
||||||
|
if not row:
|
||||||
|
raise HTTPException(status_code=404, detail=f"unknown code: {code}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
hs = tuple(int(x) for x in horizons.split(",") if x.strip())
|
||||||
|
if not hs or any(h < 1 or h > 30 for h in hs):
|
||||||
|
raise ValueError("invalid horizons")
|
||||||
|
except ValueError as exc:
|
||||||
|
raise HTTPException(status_code=400, detail=f"bad horizons: {exc}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = predict_and_store(code, horizons=hs, user_triggered=user_triggered)
|
||||||
|
except RuntimeError as exc:
|
||||||
|
raise HTTPException(status_code=409, detail=str(exc))
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
logger.exception("predict_and_store failed for %s", code)
|
||||||
|
raise HTTPException(status_code=500, detail=str(exc))
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{code}/latest")
|
||||||
|
def latest_prediction(code: str) -> dict:
|
||||||
|
"""가장 최근 base_date 에 대한 'ensemble' 예측 묶음."""
|
||||||
|
eng = get_engine()
|
||||||
|
with eng.connect() as conn:
|
||||||
|
symbol = conn.execute(
|
||||||
|
text("SELECT code, name FROM symbols WHERE code = :c"),
|
||||||
|
{"c": code},
|
||||||
|
).first()
|
||||||
|
if not symbol:
|
||||||
|
raise HTTPException(status_code=404, detail=f"unknown code: {code}")
|
||||||
|
|
||||||
|
latest_base = conn.execute(
|
||||||
|
text(
|
||||||
|
"SELECT MAX(base_date) FROM predictions "
|
||||||
|
"WHERE code = :c AND model = 'ensemble'"
|
||||||
|
),
|
||||||
|
{"c": code},
|
||||||
|
).first()
|
||||||
|
if not latest_base or latest_base[0] is None:
|
||||||
|
return {"code": code, "found": False, "steps": []}
|
||||||
|
base_date = latest_base[0]
|
||||||
|
|
||||||
|
rows = conn.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
SELECT predicted_at, target_date, horizon, direction,
|
||||||
|
prob_up, prob_flat, prob_down, expected_return,
|
||||||
|
point_forecast, ci_low, ci_high, user_triggered,
|
||||||
|
features_snapshot
|
||||||
|
FROM predictions
|
||||||
|
WHERE code = :c AND base_date = :bd AND model = 'ensemble'
|
||||||
|
ORDER BY horizon
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{"c": code, "bd": base_date},
|
||||||
|
).all()
|
||||||
|
|
||||||
|
# base_close (그 날 ohlcv) 도 같이 반환 — UI 가 차트 마지막 점에 이어붙일 때 사용
|
||||||
|
base_close_row = conn.execute(
|
||||||
|
text("SELECT close FROM ohlcv_daily WHERE code = :c AND date = :d"),
|
||||||
|
{"c": code, "d": base_date},
|
||||||
|
).first()
|
||||||
|
base_close = float(base_close_row[0]) if base_close_row and base_close_row[0] is not None else None
|
||||||
|
|
||||||
|
steps = []
|
||||||
|
for r in rows:
|
||||||
|
feats = r[12]
|
||||||
|
if isinstance(feats, str):
|
||||||
|
try:
|
||||||
|
feats = json.loads(feats)
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
pass
|
||||||
|
steps.append(
|
||||||
|
{
|
||||||
|
"predicted_at": r[0].isoformat() if r[0] else None,
|
||||||
|
"target_date": str(r[1]),
|
||||||
|
"horizon": int(r[2]),
|
||||||
|
"direction": r[3],
|
||||||
|
"prob_up": float(r[4]) if r[4] is not None else None,
|
||||||
|
"prob_flat": float(r[5]) if r[5] is not None else None,
|
||||||
|
"prob_down": float(r[6]) if r[6] is not None else None,
|
||||||
|
"expected_return": float(r[7]) if r[7] is not None else None,
|
||||||
|
"point_close": float(r[8]) if r[8] is not None else None,
|
||||||
|
"ci_low": float(r[9]) if r[9] is not None else None,
|
||||||
|
"ci_high": float(r[10]) if r[10] is not None else None,
|
||||||
|
"user_triggered": bool(r[11]),
|
||||||
|
"features_snapshot": feats,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"code": symbol[0],
|
||||||
|
"name": symbol[1],
|
||||||
|
"found": True,
|
||||||
|
"base_date": str(base_date),
|
||||||
|
"base_close": base_close,
|
||||||
|
"steps": steps,
|
||||||
|
}
|
||||||
@@ -4,6 +4,10 @@ POST /api/refresh/{code}
|
|||||||
body: 없음
|
body: 없음
|
||||||
query: ?lookback_days=7 (기본)
|
query: ?lookback_days=7 (기본)
|
||||||
resp: refresh_one.RefreshReport.to_dict()
|
resp: refresh_one.RefreshReport.to_dict()
|
||||||
|
|
||||||
|
POST /api/refresh/seed/symbols
|
||||||
|
symbols 테이블 강제 재시드 (SEED 10 + KRX 전 종목). 부팅 시 시드가 실패한
|
||||||
|
경우 컨테이너 재기동 없이 복구하기 위한 admin 엔드포인트.
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@@ -11,6 +15,7 @@ from fastapi import APIRouter, HTTPException, Query
|
|||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
|
|
||||||
from app.db.connection import get_engine
|
from app.db.connection import get_engine
|
||||||
|
from app.fetch.symbols_seed import seed_symbols
|
||||||
from app.pipelines.refresh_one import refresh_code
|
from app.pipelines.refresh_one import refresh_code
|
||||||
|
|
||||||
router = APIRouter(prefix="/api", tags=["refresh"])
|
router = APIRouter(prefix="/api", tags=["refresh"])
|
||||||
@@ -33,3 +38,35 @@ def refresh_endpoint(
|
|||||||
raise HTTPException(status_code=404, detail=f"unknown code: {code} (symbols 테이블에 없음. 시드 필요)")
|
raise HTTPException(status_code=404, detail=f"unknown code: {code} (symbols 테이블에 없음. 시드 필요)")
|
||||||
report = refresh_code(code, name, lookback_days=lookback_days)
|
report = refresh_code(code, name, lookback_days=lookback_days)
|
||||||
return report.to_dict()
|
return report.to_dict()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/refresh/seed/symbols")
|
||||||
|
def reseed_symbols() -> dict:
|
||||||
|
"""symbols 테이블 강제 재시드.
|
||||||
|
|
||||||
|
호출 예 (Windows cmd):
|
||||||
|
curl -X POST http://localhost:8000/api/refresh/seed/symbols
|
||||||
|
|
||||||
|
KRX 가 주말/장 마감 시간에 비정상 응답을 줄 때도 SEED 10 종목은 항상 보장하므로
|
||||||
|
엔드포인트는 200 을 돌려준다. 부분 성공 정보는 응답 body 에 담아 사용자가 판단.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
report = seed_symbols()
|
||||||
|
return {
|
||||||
|
"ok": True,
|
||||||
|
"inserted": report.inserted,
|
||||||
|
"updated": report.updated,
|
||||||
|
"seed_marked": report.seed_marked,
|
||||||
|
"markets": report.markets,
|
||||||
|
}
|
||||||
|
except Exception as e: # noqa: BLE001
|
||||||
|
# seed_symbols 내부에서 다 잡지만, 만에 하나 외부로 새는 예외 (logger 포매터
|
||||||
|
# 자체 버그 등) 도 200 으로 흡수해서 SEED 10 만이라도 살리는 게 UX 목표.
|
||||||
|
return {
|
||||||
|
"ok": False,
|
||||||
|
"inserted": 0,
|
||||||
|
"updated": 0,
|
||||||
|
"seed_marked": 0,
|
||||||
|
"markets": {},
|
||||||
|
"error": repr(e)[:300],
|
||||||
|
}
|
||||||
|
|||||||
101
backend/app/api/symbols.py
Normal file
101
backend/app/api/symbols.py
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
"""종목 검색 / 메타 API."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Query
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from app.db.connection import get_engine
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/symbols", tags=["symbols"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/search")
|
||||||
|
def search_symbols(
|
||||||
|
q: str = Query(..., min_length=1, max_length=40, description="종목명 또는 코드 prefix/부분 일치"),
|
||||||
|
limit: int = Query(default=20, ge=1, le=100),
|
||||||
|
seed_only: bool = Query(default=False, description="true 면 학습/배치 대상 10종목만"),
|
||||||
|
) -> dict:
|
||||||
|
"""이름은 trigram + ILIKE, 코드는 prefix 매치.
|
||||||
|
|
||||||
|
우선순위:
|
||||||
|
1) 코드가 정확히 같으면 가장 위
|
||||||
|
2) 이름 prefix 매치
|
||||||
|
3) 이름 부분 매치 (trigram similarity)
|
||||||
|
"""
|
||||||
|
q_norm = q.strip()
|
||||||
|
if not q_norm:
|
||||||
|
raise HTTPException(status_code=400, detail="empty query")
|
||||||
|
|
||||||
|
eng = get_engine()
|
||||||
|
where_seed = "AND is_seed = TRUE" if seed_only else ""
|
||||||
|
sql = text(
|
||||||
|
f"""
|
||||||
|
WITH ranked AS (
|
||||||
|
SELECT code, name, market, sector, is_seed,
|
||||||
|
CASE
|
||||||
|
WHEN code = :q THEN 0
|
||||||
|
WHEN code LIKE :prefix THEN 1
|
||||||
|
WHEN name LIKE :prefix THEN 2
|
||||||
|
WHEN name ILIKE :contains THEN 3
|
||||||
|
ELSE 4
|
||||||
|
END AS rank,
|
||||||
|
similarity(name, :q) AS sim
|
||||||
|
FROM symbols
|
||||||
|
WHERE active = TRUE
|
||||||
|
{where_seed}
|
||||||
|
AND (code LIKE :prefix OR name ILIKE :contains OR similarity(name, :q) > 0.2)
|
||||||
|
)
|
||||||
|
SELECT code, name, market, sector, is_seed
|
||||||
|
FROM ranked
|
||||||
|
ORDER BY rank ASC, sim DESC, name ASC
|
||||||
|
LIMIT :lim
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
with eng.connect() as conn:
|
||||||
|
rows = conn.execute(
|
||||||
|
sql,
|
||||||
|
{
|
||||||
|
"q": q_norm,
|
||||||
|
"prefix": f"{q_norm}%",
|
||||||
|
"contains": f"%{q_norm}%",
|
||||||
|
"lim": limit,
|
||||||
|
},
|
||||||
|
).all()
|
||||||
|
return {
|
||||||
|
"q": q_norm,
|
||||||
|
"count": len(rows),
|
||||||
|
"items": [
|
||||||
|
{
|
||||||
|
"code": r[0],
|
||||||
|
"name": r[1],
|
||||||
|
"market": r[2],
|
||||||
|
"sector": r[3],
|
||||||
|
"is_seed": bool(r[4]),
|
||||||
|
}
|
||||||
|
for r in rows
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{code}")
|
||||||
|
def get_symbol(code: str) -> dict:
|
||||||
|
eng = get_engine()
|
||||||
|
with eng.connect() as conn:
|
||||||
|
row = conn.execute(
|
||||||
|
text(
|
||||||
|
"SELECT code, name, market, sector, is_seed, active, created_at "
|
||||||
|
"FROM symbols WHERE code = :c"
|
||||||
|
),
|
||||||
|
{"c": code},
|
||||||
|
).first()
|
||||||
|
if not row:
|
||||||
|
raise HTTPException(status_code=404, detail=f"unknown code: {code}")
|
||||||
|
return {
|
||||||
|
"code": row[0],
|
||||||
|
"name": row[1],
|
||||||
|
"market": row[2],
|
||||||
|
"sector": row[3],
|
||||||
|
"is_seed": bool(row[4]),
|
||||||
|
"active": bool(row[5]),
|
||||||
|
"created_at": str(row[6]) if row[6] else None,
|
||||||
|
}
|
||||||
53
backend/app/db/migrate.py
Normal file
53
backend/app/db/migrate.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
"""Manual migration runner.
|
||||||
|
|
||||||
|
docker-entrypoint-initdb.d 는 fresh DB 첫 기동 때만 동작. 이미 동작 중인 DB 에
|
||||||
|
새 마이그레이션을 적용하려면 이 스크립트로:
|
||||||
|
|
||||||
|
python -m app.db.migrate
|
||||||
|
|
||||||
|
모든 SQL 파일은 idempotent (CREATE IF NOT EXISTS / CREATE OR REPLACE) 여야 함.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from app.db.connection import get_engine
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
MIGRATIONS_DIR = Path(__file__).parent / "migrations"
|
||||||
|
|
||||||
|
|
||||||
|
def apply_all() -> dict[str, str]:
|
||||||
|
"""migrations/ 안 .sql 들을 이름순으로 적용. 결과: {filename: 'ok'|'failed: ...'}."""
|
||||||
|
eng = get_engine()
|
||||||
|
results: dict[str, str] = {}
|
||||||
|
files = sorted(MIGRATIONS_DIR.glob("*.sql"))
|
||||||
|
if not files:
|
||||||
|
logger.warning("no migration files in %s", MIGRATIONS_DIR)
|
||||||
|
return results
|
||||||
|
for path in files:
|
||||||
|
sql = path.read_text(encoding="utf-8")
|
||||||
|
# psql meta-command 제거 (\set ON_ERROR_STOP 등)
|
||||||
|
cleaned = "\n".join(
|
||||||
|
ln for ln in sql.splitlines() if not ln.strip().startswith("\\")
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
with eng.begin() as conn:
|
||||||
|
conn.execute(text(cleaned))
|
||||||
|
results[path.name] = "ok"
|
||||||
|
logger.info("applied %s", path.name)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
results[path.name] = f"failed: {exc}"
|
||||||
|
logger.exception("migration %s failed", path.name)
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
||||||
|
out = apply_all()
|
||||||
|
for k, v in out.items():
|
||||||
|
print(f"{k}: {v}")
|
||||||
32
backend/app/db/migrations/002_sentiment_view.sql
Normal file
32
backend/app/db/migrations/002_sentiment_view.sql
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
-- Phase 2: 일별 종목별 감성 집계 뷰.
|
||||||
|
-- weighted_score : 소스별 가중치 적용
|
||||||
|
-- naver_finance 1.0 (가장 직접적인 종목 페이지 뉴스)
|
||||||
|
-- google_rss 0.7 (관련성 노이즈 있음)
|
||||||
|
-- dart 0.5 (공시는 short title 만으로는 감성이 약함)
|
||||||
|
|
||||||
|
\set ON_ERROR_STOP on
|
||||||
|
|
||||||
|
CREATE OR REPLACE VIEW v_sentiment_daily AS
|
||||||
|
SELECT
|
||||||
|
code,
|
||||||
|
(published_at AT TIME ZONE 'Asia/Seoul')::date AS date,
|
||||||
|
COUNT(*) AS n_articles,
|
||||||
|
AVG(sentiment_score)::REAL AS mean_score,
|
||||||
|
AVG(CASE WHEN sentiment_label = 'positive' THEN 1.0 ELSE 0.0 END)::REAL AS pos_ratio,
|
||||||
|
AVG(CASE WHEN sentiment_label = 'negative' THEN 1.0 ELSE 0.0 END)::REAL AS neg_ratio,
|
||||||
|
AVG(CASE WHEN sentiment_label = 'neutral' THEN 1.0 ELSE 0.0 END)::REAL AS neu_ratio,
|
||||||
|
AVG(
|
||||||
|
sentiment_score * CASE source
|
||||||
|
WHEN 'naver_finance' THEN 1.0
|
||||||
|
WHEN 'google_rss' THEN 0.7
|
||||||
|
WHEN 'dart' THEN 0.5
|
||||||
|
ELSE 0.6
|
||||||
|
END
|
||||||
|
)::REAL AS weighted_score
|
||||||
|
FROM news
|
||||||
|
WHERE sentiment_score IS NOT NULL
|
||||||
|
AND code IS NOT NULL
|
||||||
|
GROUP BY code, (published_at AT TIME ZONE 'Asia/Seoul')::date;
|
||||||
|
|
||||||
|
COMMENT ON VIEW v_sentiment_daily IS
|
||||||
|
'Phase 2: KR-FinBERT 점수를 종목·일(KST) 단위로 집계. Phase 4 LGBM 피처 + UI 차트 보조 데이터로 사용.';
|
||||||
19
backend/app/db/migrations/003_ensemble_weights.sql
Normal file
19
backend/app/db/migrations/003_ensemble_weights.sql
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
-- Phase 4: 앙상블 가중치 저장.
|
||||||
|
-- (code, horizon) 별로 Chronos vs LGBM 가중치. 일요일 02:00 재학습 잡에서 갱신.
|
||||||
|
|
||||||
|
\set ON_ERROR_STOP on
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS ensemble_weights (
|
||||||
|
code TEXT NOT NULL REFERENCES symbols(code),
|
||||||
|
horizon INT NOT NULL,
|
||||||
|
w_chronos REAL NOT NULL DEFAULT 0.6,
|
||||||
|
w_lgbm REAL NOT NULL DEFAULT 0.4,
|
||||||
|
hit_rate_chronos REAL,
|
||||||
|
hit_rate_lgbm REAL,
|
||||||
|
sample_count INT,
|
||||||
|
updated_at TIMESTAMPTZ DEFAULT NOW(),
|
||||||
|
PRIMARY KEY (code, horizon)
|
||||||
|
);
|
||||||
|
|
||||||
|
COMMENT ON TABLE ensemble_weights IS
|
||||||
|
'Phase 4: (code, horizon) 별 Chronos/LGBM 가중치. 최근 30일 prediction_outcomes hit_rate 기반 매주 갱신.';
|
||||||
@@ -13,11 +13,14 @@ status='skipped_missing_key' 처리.
|
|||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import date, datetime
|
from datetime import date, datetime
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
@@ -30,6 +33,16 @@ logger = logging.getLogger(__name__)
|
|||||||
KIS_BASE = "https://openapi.koreainvestment.com:9443"
|
KIS_BASE = "https://openapi.koreainvestment.com:9443"
|
||||||
USER_AGENT = "stock_chart_site/0.1 (+personal)"
|
USER_AGENT = "stock_chart_site/0.1 (+personal)"
|
||||||
|
|
||||||
|
# 토큰 디스크 캐시 경로. 기본값은 컨테이너 안 /app/.cache/kis_token.json — docker-compose
|
||||||
|
# 의 `./backend:/app` 바인드 마운트 덕에 호스트 `./backend/.cache/` 에 영속된다.
|
||||||
|
# `backend/.cache/` 는 .gitignore 에 들어있어 secrets 가 커밋되지 않는다.
|
||||||
|
#
|
||||||
|
# 왜 디스크 캐시가 필요한가:
|
||||||
|
# KIS 는 access_token 발급을 1분 1회, 하루 N회로 강하게 제한한다. 메모리만 쓰면
|
||||||
|
# `restart.bat` / `build.bat` / 컨테이너 재기동 때마다 새 발급 → 403 (EGW00133 등) 빈발.
|
||||||
|
# 토큰 자체는 24시간 유효하므로, 컨테이너 인스턴스가 바뀌어도 같은 토큰을 재사용한다.
|
||||||
|
_TOKEN_CACHE_PATH = Path(os.environ.get("KIS_TOKEN_CACHE_PATH", "/app/.cache/kis_token.json"))
|
||||||
|
|
||||||
|
|
||||||
class SkippedMissingKey(RuntimeError):
|
class SkippedMissingKey(RuntimeError):
|
||||||
"""KIS 키 미설정 시 발생. 호출 측에서 skipped 로 매핑."""
|
"""KIS 키 미설정 시 발생. 호출 측에서 skipped 로 매핑."""
|
||||||
@@ -49,6 +62,54 @@ def _has_keys() -> bool:
|
|||||||
return bool(settings.kis_app_key and settings.kis_app_secret)
|
return bool(settings.kis_app_key and settings.kis_app_secret)
|
||||||
|
|
||||||
|
|
||||||
|
def _current_key_prefix() -> str:
|
||||||
|
# app_key 가 바뀌었는데 옛 키로 받은 토큰을 그대로 쓰면 401. 캐시 무효화 키로 사용.
|
||||||
|
return (settings.kis_app_key or "")[:8]
|
||||||
|
|
||||||
|
|
||||||
|
def _load_disk_cache() -> _Token | None:
|
||||||
|
try:
|
||||||
|
with _TOKEN_CACHE_PATH.open() as f:
|
||||||
|
data = json.load(f)
|
||||||
|
if data.get("key_prefix") != _current_key_prefix():
|
||||||
|
# .env 에서 app_key 가 바뀌었을 가능성 → 캐시 폐기
|
||||||
|
return None
|
||||||
|
tok = _Token(value=str(data["value"]), expires_at=float(data["expires_at"]))
|
||||||
|
if tok.expires_at <= time.time():
|
||||||
|
return None
|
||||||
|
return tok
|
||||||
|
except FileNotFoundError:
|
||||||
|
return None
|
||||||
|
except (OSError, ValueError, KeyError, TypeError) as exc:
|
||||||
|
logger.warning("kis token disk-cache read failed (%s): %s", _TOKEN_CACHE_PATH, exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _save_disk_cache(tok: _Token) -> None:
|
||||||
|
try:
|
||||||
|
_TOKEN_CACHE_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
tmp = _TOKEN_CACHE_PATH.with_suffix(".json.tmp")
|
||||||
|
# atomic write: 부분 쓰기 중 컨테이너가 죽어도 다음 시작 시 깨진 파일 안 읽음
|
||||||
|
with tmp.open("w") as f:
|
||||||
|
json.dump(
|
||||||
|
{
|
||||||
|
"value": tok.value,
|
||||||
|
"expires_at": tok.expires_at,
|
||||||
|
"key_prefix": _current_key_prefix(),
|
||||||
|
},
|
||||||
|
f,
|
||||||
|
)
|
||||||
|
os.replace(tmp, _TOKEN_CACHE_PATH)
|
||||||
|
# 토큰 파일은 키 동등의 secret. 0600 권한.
|
||||||
|
try:
|
||||||
|
os.chmod(_TOKEN_CACHE_PATH, 0o600)
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
except OSError as exc:
|
||||||
|
# 캐시 쓰기 실패는 치명적이지 않음 — 메모리 캐시로만 동작 가능. 경고만.
|
||||||
|
logger.warning("kis token disk-cache write failed (%s): %s", _TOKEN_CACHE_PATH, exc)
|
||||||
|
|
||||||
|
|
||||||
@retry(
|
@retry(
|
||||||
stop=stop_after_attempt(3),
|
stop=stop_after_attempt(3),
|
||||||
wait=wait_exponential(multiplier=1, min=1, max=8),
|
wait=wait_exponential(multiplier=1, min=1, max=8),
|
||||||
@@ -75,13 +136,29 @@ def _issue_token() -> _Token:
|
|||||||
|
|
||||||
|
|
||||||
def get_token() -> str:
|
def get_token() -> str:
|
||||||
"""캐시된 토큰 반환. 만료 60초 전부터 재발급. 키 없으면 SkippedMissingKey."""
|
"""캐시된 토큰 반환. 메모리 → 디스크 → 신규 발급 순. 키 없으면 SkippedMissingKey.
|
||||||
|
|
||||||
|
디스크 캐시는 컨테이너 재기동 시 토큰 재발급 1분 제한 (EGW00133) 회피용.
|
||||||
|
"""
|
||||||
global _token_cache
|
global _token_cache
|
||||||
with _token_lock:
|
with _token_lock:
|
||||||
if _token_cache and _token_cache.expires_at > time.time():
|
if _token_cache and _token_cache.expires_at > time.time():
|
||||||
return _token_cache.value
|
return _token_cache.value
|
||||||
|
disk = _load_disk_cache()
|
||||||
|
if disk is not None:
|
||||||
|
_token_cache = disk
|
||||||
|
logger.info(
|
||||||
|
"kis token loaded from disk, expires_at=%s",
|
||||||
|
datetime.fromtimestamp(disk.expires_at),
|
||||||
|
)
|
||||||
|
return disk.value
|
||||||
_token_cache = _issue_token()
|
_token_cache = _issue_token()
|
||||||
logger.info("kis token issued, expires_at=%s", datetime.fromtimestamp(_token_cache.expires_at))
|
_save_disk_cache(_token_cache)
|
||||||
|
logger.info(
|
||||||
|
"kis token issued (and cached to %s), expires_at=%s",
|
||||||
|
_TOKEN_CACHE_PATH,
|
||||||
|
datetime.fromtimestamp(_token_cache.expires_at),
|
||||||
|
)
|
||||||
return _token_cache.value
|
return _token_cache.value
|
||||||
|
|
||||||
|
|
||||||
@@ -156,6 +233,130 @@ def fetch_daily_price(
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
@retry(
|
||||||
|
stop=stop_after_attempt(2),
|
||||||
|
wait=wait_exponential(multiplier=1, min=1, max=4),
|
||||||
|
retry=retry_if_exception_type((httpx.HTTPError, httpx.TimeoutException)),
|
||||||
|
reraise=True,
|
||||||
|
)
|
||||||
|
def fetch_minute_price(
|
||||||
|
code: str,
|
||||||
|
*,
|
||||||
|
end_hour: str | None = None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""당일 1분봉 시세 조회 (read-only). 최신 30개 캔들을 반환.
|
||||||
|
|
||||||
|
KIS 분봉 endpoint (`inquire-time-itemchartprice`) 는 base 시각 (FID_INPUT_HOUR_1) 부터
|
||||||
|
역순으로 최대 30개의 1분봉을 돌려준다. base 를 비우면 KIS 가 가장 최근 시각으로 해석.
|
||||||
|
즉 장중 호출 → 직전 30분 / 장 종료 후 호출 → 15:00~15:30 의 30분.
|
||||||
|
|
||||||
|
Returns: [{ts: datetime(KST aware), open, high, low, close, volume}, ...]
|
||||||
|
ts 오름차순 정렬.
|
||||||
|
|
||||||
|
Note: 이 endpoint 는 "당일" 분봉만 지원. 어제 이전 분봉은 별도 endpoint 가 필요한데,
|
||||||
|
이 사이트의 사용 패턴 (장중 라이브 차트) 에는 당일 데이터로 충분하다.
|
||||||
|
"""
|
||||||
|
if not _has_keys():
|
||||||
|
raise SkippedMissingKey("kis app_key/secret missing")
|
||||||
|
url = f"{KIS_BASE}/uapi/domestic-stock/v1/quotations/inquire-time-itemchartprice"
|
||||||
|
params = {
|
||||||
|
"FID_ETC_CLS_CODE": "",
|
||||||
|
"FID_COND_MRKT_DIV_CODE": "J",
|
||||||
|
"FID_INPUT_ISCD": code,
|
||||||
|
# 비우면 KIS 가 "지금" 으로 해석. 장 마감 후엔 15:30:00 부근 데이터.
|
||||||
|
"FID_INPUT_HOUR_1": end_hour or "",
|
||||||
|
"FID_PW_DATA_INCU_YN": "Y", # 과거 데이터 포함 (장 시작 직후 빈 데이터 방지)
|
||||||
|
}
|
||||||
|
with httpx.Client(timeout=15.0) as cli:
|
||||||
|
resp = cli.get(url, headers=_headers("FHKST03010200"), params=params)
|
||||||
|
resp.raise_for_status()
|
||||||
|
data = resp.json()
|
||||||
|
if data.get("rt_cd") != "0":
|
||||||
|
raise RuntimeError(f"kis error: {data.get('msg1')} (rt_cd={data.get('rt_cd')})")
|
||||||
|
|
||||||
|
# KIS 응답은 KST. tz-aware 로 변환해서 DB (TIMESTAMPTZ) 에 안전 적재.
|
||||||
|
from datetime import timedelta, timezone as _tz
|
||||||
|
KST = _tz(timedelta(hours=9))
|
||||||
|
|
||||||
|
out: list[dict[str, Any]] = []
|
||||||
|
for row in data.get("output2", []) or []:
|
||||||
|
raw_date = row.get("stck_bsop_date")
|
||||||
|
raw_hour = row.get("stck_cntg_hour")
|
||||||
|
if not raw_date or not raw_hour:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
ts = datetime.strptime(raw_date + raw_hour.zfill(6), "%Y%m%d%H%M%S").replace(tzinfo=KST)
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
out.append(
|
||||||
|
{
|
||||||
|
"ts": ts,
|
||||||
|
"open": float(row.get("stck_oprc") or 0),
|
||||||
|
"high": float(row.get("stck_hgpr") or 0),
|
||||||
|
"low": float(row.get("stck_lwpr") or 0),
|
||||||
|
# 분봉에서는 종가가 stck_prpr (현재가) 로 옴
|
||||||
|
"close": float(row.get("stck_prpr") or row.get("stck_clpr") or 0),
|
||||||
|
"volume": int(row.get("cntg_vol") or 0),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# KIS 응답은 보통 최신→과거 역순. UI/DB 적재 편의를 위해 오름차순으로 뒤집는다.
|
||||||
|
out.sort(key=lambda r: r["ts"])
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_minute_range(
|
||||||
|
code: str,
|
||||||
|
from_ts: datetime,
|
||||||
|
to_ts: datetime,
|
||||||
|
*,
|
||||||
|
max_pages: int = 20,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""[from_ts, to_ts] 윈도우의 1분봉 전체. KIS 30-bar 페이지를 역순으로 반복 호출.
|
||||||
|
|
||||||
|
KIS `inquire-time-itemchartprice` 는 한 번에 최대 30개 1분봉만 주고,
|
||||||
|
`FID_INPUT_HOUR_1` 기준 그 시각 포함 이전 30분을 반환한다. 그래서 to_ts 부터
|
||||||
|
시작해서 가장 이른 응답 시각의 -1분을 다음 cursor 로 잡아 from_ts 까지 후퇴.
|
||||||
|
|
||||||
|
중복 키 (ts) 는 dict 로 자연 dedupe. 더 이상 새 행이 안 들어오거나 max_pages 도달
|
||||||
|
하면 종료 (rate-limit/무한루프 방지).
|
||||||
|
|
||||||
|
Note: 이 endpoint 는 "당일" 만 지원. from_ts/to_ts 는 같은 날짜여야 한다.
|
||||||
|
"""
|
||||||
|
if not _has_keys():
|
||||||
|
raise SkippedMissingKey("kis app_key/secret missing")
|
||||||
|
if from_ts >= to_ts:
|
||||||
|
return []
|
||||||
|
|
||||||
|
from datetime import timedelta as _td
|
||||||
|
|
||||||
|
accumulated: dict[datetime, dict[str, Any]] = {}
|
||||||
|
cursor = to_ts
|
||||||
|
pages = 0
|
||||||
|
while cursor > from_ts and pages < max_pages:
|
||||||
|
pages += 1
|
||||||
|
rows = fetch_minute_price(code, end_hour=cursor.strftime("%H%M%S"))
|
||||||
|
if not rows:
|
||||||
|
break
|
||||||
|
added = 0
|
||||||
|
for r in rows:
|
||||||
|
ts = r["ts"]
|
||||||
|
if ts < from_ts or ts > to_ts:
|
||||||
|
continue
|
||||||
|
if ts not in accumulated:
|
||||||
|
accumulated[ts] = r
|
||||||
|
added += 1
|
||||||
|
if added == 0:
|
||||||
|
# 같은 30개를 또 받았다 — 더 과거가 없거나 KIS 가 똑같은 페이지를 반환.
|
||||||
|
break
|
||||||
|
earliest_ts = min(r["ts"] for r in rows)
|
||||||
|
next_cursor = earliest_ts - _td(minutes=1)
|
||||||
|
if next_cursor >= cursor:
|
||||||
|
break
|
||||||
|
cursor = next_cursor
|
||||||
|
|
||||||
|
return sorted(accumulated.values(), key=lambda r: r["ts"])
|
||||||
|
|
||||||
|
|
||||||
def ping() -> dict[str, Any]:
|
def ping() -> dict[str, Any]:
|
||||||
"""토큰 발급만 시도해서 키 유효성 확인."""
|
"""토큰 발급만 시도해서 키 유효성 확인."""
|
||||||
if not _has_keys():
|
if not _has_keys():
|
||||||
|
|||||||
@@ -41,21 +41,70 @@ def _fetch_market_listing(market: str) -> list[tuple[str, str]]:
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _upsert_seed_tickers() -> int:
|
||||||
|
"""SEED 10종목 강제 upsert. 네트워크 불필요 → KRX 실패와 무관하게 항상 성공.
|
||||||
|
|
||||||
|
별도 트랜잭션이라 KRX 시드가 나중에 실패해도 살아남는다.
|
||||||
|
"""
|
||||||
|
engine = get_engine()
|
||||||
|
with engine.begin() as conn:
|
||||||
|
for t in SEED_TICKERS:
|
||||||
|
conn.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
INSERT INTO symbols (code, name, market, is_seed)
|
||||||
|
VALUES (:code, :name, :market, TRUE)
|
||||||
|
ON CONFLICT (code) DO UPDATE
|
||||||
|
SET name = EXCLUDED.name,
|
||||||
|
market = EXCLUDED.market,
|
||||||
|
is_seed = TRUE
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{"code": t.code, "name": t.name, "market": t.market},
|
||||||
|
)
|
||||||
|
return len(SEED_TICKERS)
|
||||||
|
|
||||||
|
|
||||||
def seed_symbols() -> SeedReport:
|
def seed_symbols() -> SeedReport:
|
||||||
"""KOSPI + KOSDAQ 전 종목을 upsert. SEED 10 종목은 is_seed=TRUE."""
|
"""KOSPI + KOSDAQ 전 종목을 upsert. SEED 10 종목은 is_seed=TRUE.
|
||||||
rows: list[tuple[str, str, str]] = [] # (code, name, market)
|
|
||||||
|
순서:
|
||||||
|
1) SEED_TICKERS 먼저 별도 트랜잭션으로 강제 upsert (KRX 실패와 무관하게 검색 가능)
|
||||||
|
2) KRX 리스팅 fetch (네트워크 의존) → 별도 트랜잭션으로 일괄 upsert.
|
||||||
|
시장별 fetch 실패 시 해당 시장만 스킵하고 나머지 진행.
|
||||||
|
"""
|
||||||
|
# 1) SEED_TICKERS — 항상 보장
|
||||||
|
try:
|
||||||
|
_upsert_seed_tickers()
|
||||||
|
seed_marked = len(SEED_TICKERS)
|
||||||
|
logger.info("seed_symbols: seed-tickers upserted (%d)", seed_marked)
|
||||||
|
except Exception as e: # noqa: BLE001
|
||||||
|
# logger.exception 은 Python 3.11 의 traceback 포매터가 pykrx 소스의 한글 주석
|
||||||
|
# 'df = 가...' 바이트를 만나면 UnicodeDecodeError 를 던지는 버그가 있어, 그 예외가
|
||||||
|
# try 밖으로 escape 해서 500 을 만든다. 그래서 traceback 안 찍는다.
|
||||||
|
logger.error("seed_symbols: seed-tickers upsert failed: %s", repr(e)[:300])
|
||||||
|
seed_marked = 0
|
||||||
|
|
||||||
|
# 2) KRX 전 종목 — fetch 실패해도 부분 성공 허용
|
||||||
market_counts: dict[str, int] = {}
|
market_counts: dict[str, int] = {}
|
||||||
|
all_rows: list[tuple[str, str, str]] = []
|
||||||
for market in ("KOSPI", "KOSDAQ"):
|
for market in ("KOSPI", "KOSDAQ"):
|
||||||
|
try:
|
||||||
listing = _fetch_market_listing(market)
|
listing = _fetch_market_listing(market)
|
||||||
market_counts[market] = len(listing)
|
market_counts[market] = len(listing)
|
||||||
for code, name in listing:
|
for code, name in listing:
|
||||||
rows.append((code, name, market))
|
all_rows.append((code, name, market))
|
||||||
|
logger.info("seed_symbols: KRX %s fetched (%d)", market, len(listing))
|
||||||
|
except Exception as e: # noqa: BLE001
|
||||||
|
logger.error("seed_symbols: KRX %s fetch failed — skip market: %s", market, repr(e)[:300])
|
||||||
|
market_counts[market] = 0
|
||||||
|
|
||||||
engine = get_engine()
|
|
||||||
inserted = updated = 0
|
inserted = updated = 0
|
||||||
seed_marked = 0
|
if all_rows:
|
||||||
|
engine = get_engine()
|
||||||
|
try:
|
||||||
with engine.begin() as conn:
|
with engine.begin() as conn:
|
||||||
for code, name, market in rows:
|
for code, name, market in all_rows:
|
||||||
is_seed = code in SEED_CODES
|
is_seed = code in SEED_CODES
|
||||||
res = conn.execute(
|
res = conn.execute(
|
||||||
text(
|
text(
|
||||||
@@ -76,21 +125,8 @@ def seed_symbols() -> SeedReport:
|
|||||||
inserted += 1
|
inserted += 1
|
||||||
else:
|
else:
|
||||||
updated += 1
|
updated += 1
|
||||||
if is_seed:
|
except Exception as e: # noqa: BLE001
|
||||||
seed_marked += 1
|
logger.error("seed_symbols: KRX bulk upsert failed (transaction rolled back): %s", repr(e)[:300])
|
||||||
|
|
||||||
# SEED_TICKERS 중 KRX 리스팅에 없으면 (상장폐지 등) 그래도 명시적으로 시드 row 보장
|
|
||||||
for t in SEED_TICKERS:
|
|
||||||
conn.execute(
|
|
||||||
text(
|
|
||||||
"""
|
|
||||||
INSERT INTO symbols (code, name, market, is_seed)
|
|
||||||
VALUES (:code, :name, :market, TRUE)
|
|
||||||
ON CONFLICT (code) DO UPDATE SET is_seed = TRUE
|
|
||||||
"""
|
|
||||||
),
|
|
||||||
{"code": t.code, "name": t.name, "market": t.market},
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"seed_symbols done: inserted=%d updated=%d seed_marked=%d markets=%s",
|
"seed_symbols done: inserted=%d updated=%d seed_marked=%d markets=%s",
|
||||||
|
|||||||
@@ -1,14 +1,21 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from app.api.chart import router as chart_router
|
||||||
|
from app.api.metrics import router as metrics_router
|
||||||
|
from app.api.news import router as news_router
|
||||||
|
from app.api.predict import router as predict_router
|
||||||
from app.api.refresh import router as refresh_router
|
from app.api.refresh import router as refresh_router
|
||||||
|
from app.api.symbols import router as symbols_router
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.db.connection import ping as db_ping
|
from app.db.connection import get_engine, ping as db_ping
|
||||||
from app.fetch import dart as dart_mod
|
from app.fetch import dart as dart_mod
|
||||||
from app.fetch import kis as kis_mod
|
from app.fetch import kis as kis_mod
|
||||||
from app.pipelines.scheduler import shutdown_scheduler, start_scheduler
|
from app.pipelines.scheduler import shutdown_scheduler, start_scheduler
|
||||||
@@ -20,9 +27,58 @@ logging.basicConfig(
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _bootstrap_db() -> None:
|
||||||
|
"""첫 부팅 자동화:
|
||||||
|
1) migrations/*.sql idempotent 적용 (timescale/pgvector 확장 + 스키마)
|
||||||
|
2) symbols 테이블 비어있으면 pykrx 로 전 종목 시드 (SEED 10 마크 포함)
|
||||||
|
|
||||||
|
BOOTSTRAP_DISABLED=1 이면 스킵 (테스트/CI 용). 어떤 단계든 실패해도 서버는
|
||||||
|
뜬다 — /health/db 가 진단을 알려준다.
|
||||||
|
"""
|
||||||
|
if os.environ.get("BOOTSTRAP_DISABLED") == "1":
|
||||||
|
logger.info("bootstrap skipped (BOOTSTRAP_DISABLED=1)")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 1) migrations
|
||||||
|
try:
|
||||||
|
from app.db.migrate import apply_all
|
||||||
|
res = apply_all()
|
||||||
|
logger.info("bootstrap migrate: %s", res)
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
logger.exception("bootstrap migrate failed")
|
||||||
|
return # 스키마 없으면 시드 불가
|
||||||
|
|
||||||
|
# 2) symbols 시드
|
||||||
|
# - SEED 10종목은 매 부팅마다 무조건 upsert (10회 upsert, ms 단위, 네트워크 무관)
|
||||||
|
# → KRX 접근 실패한 환경에서도 최소 10종목 검색 보장
|
||||||
|
# - KRX 전 종목 fetch 는 symbols 가 비어있을 때만 (호출 비용 큼)
|
||||||
|
try:
|
||||||
|
from app.fetch.symbols_seed import _upsert_seed_tickers, seed_symbols
|
||||||
|
n_seed = _upsert_seed_tickers()
|
||||||
|
logger.info("bootstrap seed-tickers ensured (%d)", n_seed)
|
||||||
|
|
||||||
|
eng = get_engine()
|
||||||
|
with eng.connect() as conn:
|
||||||
|
row = conn.execute(text("SELECT COUNT(*) FROM symbols")).first()
|
||||||
|
count = int(row[0]) if row else 0
|
||||||
|
if count <= n_seed:
|
||||||
|
# symbols 가 SEED 만큼 또는 그 이하 → KRX 전 종목 fetch 시도
|
||||||
|
logger.info("symbols sparse (count=%d) — running KRX listing seed", count)
|
||||||
|
report = seed_symbols()
|
||||||
|
logger.info("bootstrap seed_symbols: %s", report)
|
||||||
|
else:
|
||||||
|
logger.info("symbols already populated (count=%d) — skip KRX listing seed", count)
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
logger.exception("bootstrap seed_symbols failed")
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(_: FastAPI):
|
async def lifespan(_: FastAPI):
|
||||||
|
_bootstrap_db()
|
||||||
# 스케줄러는 옵션. CI/테스트에서 disable 하고 싶으면 SCHEDULER_DISABLED 같은 env 추가 가능.
|
# 스케줄러는 옵션. CI/테스트에서 disable 하고 싶으면 SCHEDULER_DISABLED 같은 env 추가 가능.
|
||||||
|
if os.environ.get("SCHEDULER_DISABLED") == "1":
|
||||||
|
logger.info("scheduler skipped (SCHEDULER_DISABLED=1)")
|
||||||
|
else:
|
||||||
try:
|
try:
|
||||||
start_scheduler()
|
start_scheduler()
|
||||||
except Exception: # noqa: BLE001
|
except Exception: # noqa: BLE001
|
||||||
@@ -41,6 +97,11 @@ app.add_middleware(
|
|||||||
)
|
)
|
||||||
|
|
||||||
app.include_router(refresh_router)
|
app.include_router(refresh_router)
|
||||||
|
app.include_router(symbols_router)
|
||||||
|
app.include_router(chart_router)
|
||||||
|
app.include_router(predict_router)
|
||||||
|
app.include_router(metrics_router)
|
||||||
|
app.include_router(news_router)
|
||||||
|
|
||||||
|
|
||||||
def _resolved_device() -> str:
|
def _resolved_device() -> str:
|
||||||
@@ -71,3 +132,30 @@ def health_keys() -> dict[str, object]:
|
|||||||
"dart": dart_mod.ping(),
|
"dart": dart_mod.ping(),
|
||||||
# huggingface 는 모델 다운로드 시점에 확인 (별도 ping 호출 비용 회피)
|
# huggingface 는 모델 다운로드 시점에 확인 (별도 ping 호출 비용 회피)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health/models")
|
||||||
|
def health_models() -> dict[str, object]:
|
||||||
|
"""Chronos / LGBM 가용성 진단.
|
||||||
|
|
||||||
|
Chronos: lazy 로드 첫 호출이라 30초~수 분 걸릴 수 있음 (HuggingFace 다운로드).
|
||||||
|
LGBM: 체크포인트 디렉토리 스캔 — retrain 안 돈 cold start 에선 비어있음.
|
||||||
|
"""
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from app.models import chronos as chronos_mod
|
||||||
|
|
||||||
|
lgbm_dir = Path(os.environ.get("LGBM_MODEL_DIR", "/app/data/models"))
|
||||||
|
lgbm_files: list[str] = []
|
||||||
|
if lgbm_dir.exists():
|
||||||
|
lgbm_files = sorted(p.name for p in lgbm_dir.glob("*.pkl"))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"chronos": chronos_mod.ping(),
|
||||||
|
"lgbm": {
|
||||||
|
"model_dir": str(lgbm_dir),
|
||||||
|
"checkpoint_count": len(lgbm_files),
|
||||||
|
"samples": lgbm_files[:8], # 너무 많으면 잘라서.
|
||||||
|
"status": "ok" if lgbm_files else "no_checkpoints (cold start, run retrain_weekly)",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|||||||
169
backend/app/models/chronos.py
Normal file
169
backend/app/models/chronos.py
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
"""Chronos zero-shot 시계열 예측 어댑터.
|
||||||
|
|
||||||
|
모델: amazon/chronos-t5-small (46M, 빠르고 RTX 3070 Ti 에 충분히 들어감).
|
||||||
|
환경변수 CHRONOS_MODEL 로 base/large 로 바꿀 수 있음.
|
||||||
|
|
||||||
|
입력: 종가 시계열 (list[float], 최소 32 step).
|
||||||
|
출력: horizon 일 quantile forecast (q10/median/q90).
|
||||||
|
|
||||||
|
lazy singleton 으로 첫 호출 시 모델 로드. 디바이스는 settings.model_device 따라.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
MODEL_NAME = os.environ.get("CHRONOS_MODEL", "amazon/chronos-t5-small")
|
||||||
|
|
||||||
|
_lock = threading.Lock()
|
||||||
|
_state: dict[str, object] = {"loaded": False, "pipe": None, "device": None}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChronosForecast:
|
||||||
|
horizon: int
|
||||||
|
median: list[float]
|
||||||
|
q10: list[float]
|
||||||
|
q90: list[float]
|
||||||
|
samples: list[list[float]] # raw samples for ensemble downstream
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_device() -> str:
|
||||||
|
import torch # lazy
|
||||||
|
|
||||||
|
pref = (settings.model_device or "auto").lower()
|
||||||
|
if pref == "cuda":
|
||||||
|
return "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
if pref == "cpu":
|
||||||
|
return "cpu"
|
||||||
|
return "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
|
||||||
|
def _load() -> None:
|
||||||
|
global _state
|
||||||
|
with _lock:
|
||||||
|
if _state["loaded"]:
|
||||||
|
return
|
||||||
|
import torch
|
||||||
|
from chronos import ChronosPipeline
|
||||||
|
|
||||||
|
token = settings.huggingface_token or None
|
||||||
|
if token:
|
||||||
|
os.environ.setdefault("HUGGINGFACE_HUB_TOKEN", token)
|
||||||
|
os.environ.setdefault("HF_TOKEN", token)
|
||||||
|
|
||||||
|
device = _resolve_device()
|
||||||
|
# dtype 선택:
|
||||||
|
# - 이전엔 cuda 면 무조건 bf16 으로 갔는데, torch 2.3.1+cu121 사전빌드 wheel 이
|
||||||
|
# sm_86 (RTX 3070 Ti) 의 일부 T5 커널 binary 를 빠뜨려서 inference 첫 호출에
|
||||||
|
# "no kernel image is available for execution on the device" 발생. ping/load
|
||||||
|
# 까지는 통과해서 진단이 까다로웠음 (실제 005930 케이스에서 관측).
|
||||||
|
# - chronos-t5-small 은 46M params 라 fp32 로도 8GB VRAM 에 여유 충분, 속도
|
||||||
|
# 차이도 일봉 30일 예측에선 무시 가능. 호환성 우선해 default 를 fp32 로.
|
||||||
|
# - 드라이버/torch 업그레이드 후 다시 bf16 시험하려면 .env 에
|
||||||
|
# CHRONOS_DTYPE=bf16 (또는 fp16) 두면 됨.
|
||||||
|
dtype_pref = os.environ.get("CHRONOS_DTYPE", "fp32").lower()
|
||||||
|
if device == "cuda" and dtype_pref == "bf16":
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
elif device == "cuda" and dtype_pref == "fp16":
|
||||||
|
dtype = torch.float16
|
||||||
|
else:
|
||||||
|
dtype = torch.float32
|
||||||
|
logger.info("loading Chronos %s on %s (dtype=%s)", MODEL_NAME, device, dtype)
|
||||||
|
pipe = ChronosPipeline.from_pretrained(
|
||||||
|
MODEL_NAME,
|
||||||
|
device_map=device,
|
||||||
|
torch_dtype=dtype,
|
||||||
|
)
|
||||||
|
_state.update({"loaded": True, "pipe": pipe, "device": device})
|
||||||
|
|
||||||
|
|
||||||
|
def _reload_cpu() -> None:
|
||||||
|
"""현재 pipeline 을 폐기하고 CPU 로 강제 재로드.
|
||||||
|
|
||||||
|
cuda 환경에서 'no kernel image is available for execution on the device' 같이
|
||||||
|
런타임에야 드러나는 GPU 비호환 에러가 났을 때 자동 폴백용. 한 번 폴백하면
|
||||||
|
다음 호출부터는 CPU 그대로 사용 (재시도 비용 회피)."""
|
||||||
|
global _state
|
||||||
|
import torch
|
||||||
|
from chronos import ChronosPipeline
|
||||||
|
with _lock:
|
||||||
|
logger.warning("falling back to CPU for Chronos (GPU inference failed)")
|
||||||
|
_state.update({"loaded": False, "pipe": None, "device": None})
|
||||||
|
pipe = ChronosPipeline.from_pretrained(
|
||||||
|
MODEL_NAME,
|
||||||
|
device_map="cpu",
|
||||||
|
torch_dtype=torch.float32,
|
||||||
|
)
|
||||||
|
_state.update({"loaded": True, "pipe": pipe, "device": "cpu"})
|
||||||
|
|
||||||
|
|
||||||
|
def forecast(
|
||||||
|
series: list[float],
|
||||||
|
*,
|
||||||
|
horizon: int = 5,
|
||||||
|
num_samples: int = 30,
|
||||||
|
) -> ChronosForecast:
|
||||||
|
"""series 의 마지막 시점 이후 horizon 일 예측.
|
||||||
|
|
||||||
|
series 는 일봉 종가. 최소 32개 권장 (그보다 짧으면 Chronos 분위 안정성 떨어짐).
|
||||||
|
"""
|
||||||
|
if len(series) < 32:
|
||||||
|
raise ValueError(
|
||||||
|
f"series too short ({len(series)}) for Chronos forecast (need >=32)"
|
||||||
|
)
|
||||||
|
_load()
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
def _do_predict():
|
||||||
|
pipe = _state["pipe"]
|
||||||
|
context = torch.tensor([float(x) for x in series], dtype=torch.float32)
|
||||||
|
with torch.no_grad():
|
||||||
|
return pipe.predict(
|
||||||
|
context=context,
|
||||||
|
prediction_length=horizon,
|
||||||
|
num_samples=num_samples,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
samples = _do_predict()
|
||||||
|
except RuntimeError as exc:
|
||||||
|
# cuda 빌드/드라이버 미스매치는 inference 시점에야 드러나는 경우가 많음.
|
||||||
|
# 'no kernel image is available' / 'CUDA error' 같은 신호 잡아서 CPU 로 폴백.
|
||||||
|
msg = str(exc)
|
||||||
|
if _state.get("device") == "cuda" and (
|
||||||
|
"no kernel image" in msg
|
||||||
|
or "CUDA error" in msg
|
||||||
|
or "CUBLAS" in msg
|
||||||
|
):
|
||||||
|
_reload_cpu()
|
||||||
|
samples = _do_predict()
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
# samples: (1, num_samples, prediction_length)
|
||||||
|
arr = samples[0].cpu().float().numpy()
|
||||||
|
q10 = np.quantile(arr, 0.10, axis=0).tolist()
|
||||||
|
median = np.quantile(arr, 0.50, axis=0).tolist()
|
||||||
|
q90 = np.quantile(arr, 0.90, axis=0).tolist()
|
||||||
|
return ChronosForecast(
|
||||||
|
horizon=horizon,
|
||||||
|
median=[float(x) for x in median],
|
||||||
|
q10=[float(x) for x in q10],
|
||||||
|
q90=[float(x) for x in q90],
|
||||||
|
samples=[[float(x) for x in row] for row in arr.tolist()],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def ping() -> dict[str, object]:
|
||||||
|
try:
|
||||||
|
_load()
|
||||||
|
return {"status": "ok", "model": MODEL_NAME, "device": _state["device"]}
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
return {"status": "failed", "model": MODEL_NAME, "error": str(exc)}
|
||||||
193
backend/app/models/ensemble.py
Normal file
193
backend/app/models/ensemble.py
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
"""Chronos + LGBM 앙상블 추론.
|
||||||
|
|
||||||
|
final_price[h] = w_c * chronos.median[h-1] + w_l * lgbm.predicted_close
|
||||||
|
final_q10[h] = w_c * chronos.q10[h-1] + w_l * lgbm.predicted_close * 0.97
|
||||||
|
final_q90[h] = w_c * chronos.q90[h-1] + w_l * lgbm.predicted_close * 1.03
|
||||||
|
|
||||||
|
LGBM 은 단일 horizon 의 다음 종가(point) 만 주므로, 그 자체로는 신뢰구간이 없음.
|
||||||
|
근사로 ±3% band 를 LGBM 의 q10/q90 자리에 사용. Chronos 의 sample 분포가
|
||||||
|
주된 신뢰구간 정보 (Chronos 우세하면 ci 가 좁아짐).
|
||||||
|
|
||||||
|
direction 확률:
|
||||||
|
- LGBM 분류기에서 prob_up/flat/down (3-class) 그대로
|
||||||
|
- Chronos 는 next-day return 부호 비율: samples.shift1 / base_close - 1 의 부호
|
||||||
|
- 둘을 같은 가중치로 평균
|
||||||
|
|
||||||
|
LGBM 모델이 없으면 Chronos 단독으로 진행 (cold start).
|
||||||
|
Chronos 도 실패하면 LGBM 단독으로 진행.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from app.models.chronos import ChronosForecast
|
||||||
|
from app.models.chronos import forecast as chronos_forecast
|
||||||
|
from app.models.lgbm import LgbmForecast
|
||||||
|
from app.models.lgbm import predict_one as lgbm_predict
|
||||||
|
from app.models.weights import load_weights
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EnsembleStep:
|
||||||
|
horizon: int # 1..H 거래일 후
|
||||||
|
target_idx: int # chronos median 의 0-based 인덱스 (horizon-1)
|
||||||
|
point_close: float
|
||||||
|
ci_low: float
|
||||||
|
ci_high: float
|
||||||
|
prob_up: float
|
||||||
|
prob_flat: float
|
||||||
|
prob_down: float
|
||||||
|
direction: str # 'up' / 'flat' / 'down'
|
||||||
|
expected_return: float
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EnsemblePrediction:
|
||||||
|
code: str
|
||||||
|
base_close: float
|
||||||
|
horizons: list[int]
|
||||||
|
steps: list[EnsembleStep]
|
||||||
|
sources_used: list[str]
|
||||||
|
# shadow 저장용 원본 출력 (predict_one.py 가 ensemble + chronos 단독 + lgbm 단독
|
||||||
|
# 3 종을 predictions 에 적재해서 retrain_weekly 가 모델별 hit_rate 비교 가능하게 함).
|
||||||
|
chronos_raw: ChronosForecast | None = None
|
||||||
|
lgbm_raw: dict[int, LgbmForecast] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
def _chronos_direction(samples: list[list[float]], base_close: float, horizon: int) -> tuple[float, float, float]:
|
||||||
|
"""Chronos sample 분포에서 (prob_up, prob_flat, prob_down). ±0.3% flat band."""
|
||||||
|
if not samples:
|
||||||
|
return 0.33, 0.34, 0.33
|
||||||
|
arr = np.array(samples)[:, horizon - 1] # 해당 step 의 sample 값
|
||||||
|
ret = arr / base_close - 1.0
|
||||||
|
p_up = float((ret > 0.003).mean())
|
||||||
|
p_dn = float((ret < -0.003).mean())
|
||||||
|
p_fl = 1.0 - p_up - p_dn
|
||||||
|
return p_up, p_fl, p_dn
|
||||||
|
|
||||||
|
|
||||||
|
def predict(code: str, *, horizons: tuple[int, ...] = (1, 3, 5)) -> EnsemblePrediction:
|
||||||
|
"""한 종목에 대해 horizons 별 앙상블 예측. on-demand 추론용."""
|
||||||
|
max_h = max(horizons)
|
||||||
|
|
||||||
|
# Chronos: 종가 시계열 가져와서 max_h 까지 예측.
|
||||||
|
from app.models.features import build_features # local import
|
||||||
|
|
||||||
|
ff = build_features(code, lookback_days=400, horizons=horizons, with_targets=False)
|
||||||
|
df = ff.df
|
||||||
|
if df.empty:
|
||||||
|
raise RuntimeError(f"no OHLCV data for {code}")
|
||||||
|
closes = df["close"].astype(float).tolist()
|
||||||
|
base_close = float(closes[-1])
|
||||||
|
|
||||||
|
sources_used: list[str] = []
|
||||||
|
cf: ChronosForecast | None = None
|
||||||
|
chronos_err: str | None = None
|
||||||
|
try:
|
||||||
|
cf = chronos_forecast(closes, horizon=max_h, num_samples=30)
|
||||||
|
sources_used.append("chronos")
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
chronos_err = f"{type(exc).__name__}: {exc}"
|
||||||
|
logger.warning("chronos forecast failed for %s: %s", code, chronos_err)
|
||||||
|
|
||||||
|
steps: list[EnsembleStep] = []
|
||||||
|
lgbm_raw: dict[int, LgbmForecast] = {}
|
||||||
|
for h in horizons:
|
||||||
|
lf: LgbmForecast | None = None
|
||||||
|
lgbm_err: str | None = None
|
||||||
|
try:
|
||||||
|
lf = lgbm_predict(code, h)
|
||||||
|
if lf is not None:
|
||||||
|
sources_used.append(f"lgbm_h{h}")
|
||||||
|
lgbm_raw[h] = lf
|
||||||
|
else:
|
||||||
|
# predict_one 이 None 반환 = 체크포인트 파일 없음 (cold start).
|
||||||
|
lgbm_err = "model checkpoint not found (run retrain_weekly)"
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
lgbm_err = f"{type(exc).__name__}: {exc}"
|
||||||
|
logger.warning("lgbm predict failed for %s h=%d: %s", code, h, lgbm_err)
|
||||||
|
|
||||||
|
# 가중치 (DB 없으면 default 0.6/0.4).
|
||||||
|
w = load_weights(code, h)
|
||||||
|
wc, wl = w.w_chronos, w.w_lgbm
|
||||||
|
# 한쪽이 없으면 다른 쪽 전부.
|
||||||
|
if cf is None and lf is None:
|
||||||
|
# 사용자가 브라우저에서 바로 원인을 보게 두 에러를 그대로 노출.
|
||||||
|
raise RuntimeError(
|
||||||
|
f"both chronos & lgbm failed for {code} h={h}; "
|
||||||
|
f"chronos={chronos_err or 'unknown'}; lgbm={lgbm_err or 'unknown'}"
|
||||||
|
)
|
||||||
|
if cf is None:
|
||||||
|
wc, wl = 0.0, 1.0
|
||||||
|
if lf is None:
|
||||||
|
wc, wl = 1.0, 0.0
|
||||||
|
|
||||||
|
if cf is not None:
|
||||||
|
c_med = cf.median[h - 1]
|
||||||
|
c_q10 = cf.q10[h - 1]
|
||||||
|
c_q90 = cf.q90[h - 1]
|
||||||
|
else:
|
||||||
|
c_med = c_q10 = c_q90 = base_close # not used (wc=0)
|
||||||
|
|
||||||
|
if lf is not None:
|
||||||
|
l_close = lf.predicted_close
|
||||||
|
l_lo = l_close * 0.97
|
||||||
|
l_hi = l_close * 1.03
|
||||||
|
l_pu, l_pf, l_pd = lf.prob_up, lf.prob_flat, lf.prob_down
|
||||||
|
else:
|
||||||
|
l_close = l_lo = l_hi = base_close
|
||||||
|
l_pu = l_pf = l_pd = 0.0
|
||||||
|
|
||||||
|
point = wc * c_med + wl * l_close
|
||||||
|
lo = wc * c_q10 + wl * l_lo
|
||||||
|
hi = wc * c_q90 + wl * l_hi
|
||||||
|
|
||||||
|
if cf is not None:
|
||||||
|
cp_up, cp_fl, cp_dn = _chronos_direction(cf.samples, base_close, h)
|
||||||
|
else:
|
||||||
|
cp_up = cp_fl = cp_dn = 0.0
|
||||||
|
|
||||||
|
# direction prob: source 마다 weights 동일하게 가중평균
|
||||||
|
if lf is not None and cf is not None:
|
||||||
|
p_up = 0.5 * cp_up + 0.5 * l_pu
|
||||||
|
p_fl = 0.5 * cp_fl + 0.5 * l_pf
|
||||||
|
p_dn = 0.5 * cp_dn + 0.5 * l_pd
|
||||||
|
elif cf is not None:
|
||||||
|
p_up, p_fl, p_dn = cp_up, cp_fl, cp_dn
|
||||||
|
else:
|
||||||
|
p_up, p_fl, p_dn = l_pu, l_pf, l_pd
|
||||||
|
|
||||||
|
# 정규화 (혹시 합이 0 가 아닐 때)
|
||||||
|
s = max(p_up + p_fl + p_dn, 1e-9)
|
||||||
|
p_up, p_fl, p_dn = p_up / s, p_fl / s, p_dn / s
|
||||||
|
dir_lbl = "up" if p_up >= max(p_fl, p_dn) else ("down" if p_dn >= p_fl else "flat")
|
||||||
|
|
||||||
|
steps.append(
|
||||||
|
EnsembleStep(
|
||||||
|
horizon=h,
|
||||||
|
target_idx=h - 1,
|
||||||
|
point_close=float(point),
|
||||||
|
ci_low=float(lo),
|
||||||
|
ci_high=float(hi),
|
||||||
|
prob_up=float(p_up),
|
||||||
|
prob_flat=float(p_fl),
|
||||||
|
prob_down=float(p_dn),
|
||||||
|
direction=dir_lbl,
|
||||||
|
expected_return=float(point / base_close - 1.0),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return EnsemblePrediction(
|
||||||
|
code=code,
|
||||||
|
base_close=base_close,
|
||||||
|
horizons=list(horizons),
|
||||||
|
steps=steps,
|
||||||
|
sources_used=sources_used,
|
||||||
|
chronos_raw=cf,
|
||||||
|
lgbm_raw=lgbm_raw,
|
||||||
|
)
|
||||||
223
backend/app/models/features.py
Normal file
223
backend/app/models/features.py
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
"""모델 학습/추론용 피처 빌더.
|
||||||
|
|
||||||
|
종목 1개 + 룩백 기간을 받아 (date 단위) DataFrame 반환:
|
||||||
|
- OHLCV
|
||||||
|
- returns r1
|
||||||
|
- TA: rsi14, macd, macd_signal, atr14, bb_pct, sma20, ema12, vol_z20
|
||||||
|
- trading_value: foreign_net, institution_net, individual_net (정규화 X, scale 그대로)
|
||||||
|
- macro 정렬: kospi, kosdaq, usdkrw, us10y, kospi_r1, usdkrw_r1
|
||||||
|
- sentiment (v_sentiment_daily): mean_score, weighted_score, n_articles,
|
||||||
|
pos_minus_neg = pos_ratio - neg_ratio. 3일 롤링 mean 도 추가.
|
||||||
|
|
||||||
|
학습 타깃 (build_features 에서만 생성):
|
||||||
|
- y_close_h{1,3,5}: close.shift(-H)
|
||||||
|
- y_ret_h{1,3,5}: y_close_h / close - 1
|
||||||
|
- y_dir_h{1,3,5}: sign(y_ret_h) (1=up, -1=down, 0=flat ±0.3% 이내)
|
||||||
|
|
||||||
|
inference 용 build_features 는 dropna 안 함. 학습용 build_training_frame 은 dropna.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import date, timedelta
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from app.db.connection import get_engine
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
FLAT_BAND = 0.003 # ±0.3% 이내는 flat
|
||||||
|
HORIZONS_DEFAULT = (1, 3, 5)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FeatureFrame:
|
||||||
|
code: str
|
||||||
|
df: pd.DataFrame
|
||||||
|
target_horizons: tuple[int, ...]
|
||||||
|
|
||||||
|
|
||||||
|
def _load_ohlcv(code: str, start: date, end: date) -> pd.DataFrame:
|
||||||
|
eng = get_engine()
|
||||||
|
sql = text(
|
||||||
|
"""
|
||||||
|
SELECT date, open, high, low, close, volume
|
||||||
|
FROM ohlcv_daily
|
||||||
|
WHERE code = :code AND date BETWEEN :s AND :e
|
||||||
|
ORDER BY date
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
with eng.connect() as conn:
|
||||||
|
rows = conn.execute(sql, {"code": code, "s": start, "e": end}).all()
|
||||||
|
if not rows:
|
||||||
|
return pd.DataFrame(columns=["date", "open", "high", "low", "close", "volume"])
|
||||||
|
df = pd.DataFrame(rows, columns=["date", "open", "high", "low", "close", "volume"])
|
||||||
|
df["date"] = pd.to_datetime(df["date"]).dt.date
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def _load_trading(code: str, start: date, end: date) -> pd.DataFrame:
|
||||||
|
eng = get_engine()
|
||||||
|
sql = text(
|
||||||
|
"""
|
||||||
|
SELECT date, foreign_net, institution_net, individual_net
|
||||||
|
FROM trading_value_daily
|
||||||
|
WHERE code = :code AND date BETWEEN :s AND :e
|
||||||
|
ORDER BY date
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
with eng.connect() as conn:
|
||||||
|
rows = conn.execute(sql, {"code": code, "s": start, "e": end}).all()
|
||||||
|
if not rows:
|
||||||
|
return pd.DataFrame(columns=["date", "foreign_net", "institution_net", "individual_net"])
|
||||||
|
df = pd.DataFrame(rows, columns=["date", "foreign_net", "institution_net", "individual_net"])
|
||||||
|
df["date"] = pd.to_datetime(df["date"]).dt.date
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def _load_macro(start: date, end: date) -> pd.DataFrame:
|
||||||
|
eng = get_engine()
|
||||||
|
sql = text(
|
||||||
|
"SELECT date, key, value FROM macro_daily "
|
||||||
|
"WHERE date BETWEEN :s AND :e ORDER BY date"
|
||||||
|
)
|
||||||
|
with eng.connect() as conn:
|
||||||
|
rows = conn.execute(sql, {"s": start, "e": end}).all()
|
||||||
|
if not rows:
|
||||||
|
return pd.DataFrame(columns=["date"])
|
||||||
|
df = pd.DataFrame(rows, columns=["date", "key", "value"])
|
||||||
|
pivot = df.pivot_table(index="date", columns="key", values="value", aggfunc="last").reset_index()
|
||||||
|
pivot["date"] = pd.to_datetime(pivot["date"]).dt.date
|
||||||
|
pivot.columns.name = None
|
||||||
|
return pivot
|
||||||
|
|
||||||
|
|
||||||
|
def _load_sentiment(code: str, start: date, end: date) -> pd.DataFrame:
|
||||||
|
eng = get_engine()
|
||||||
|
sql = text(
|
||||||
|
"""
|
||||||
|
SELECT date, n_articles, mean_score, pos_ratio, neg_ratio,
|
||||||
|
weighted_score
|
||||||
|
FROM v_sentiment_daily
|
||||||
|
WHERE code = :code AND date BETWEEN :s AND :e
|
||||||
|
ORDER BY date
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
with eng.connect() as conn:
|
||||||
|
rows = conn.execute(sql, {"code": code, "s": start, "e": end}).all()
|
||||||
|
cols = ["date", "n_articles", "mean_score", "pos_ratio", "neg_ratio", "weighted_score"]
|
||||||
|
if not rows:
|
||||||
|
return pd.DataFrame(columns=cols)
|
||||||
|
df = pd.DataFrame(rows, columns=cols)
|
||||||
|
df["date"] = pd.to_datetime(df["date"]).dt.date
|
||||||
|
df["pos_minus_neg"] = df["pos_ratio"].fillna(0) - df["neg_ratio"].fillna(0)
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def _add_ta(df: pd.DataFrame) -> pd.DataFrame:
|
||||||
|
"""ta 패키지로 기술 지표 추가."""
|
||||||
|
from ta.momentum import RSIIndicator
|
||||||
|
from ta.trend import EMAIndicator, MACD, SMAIndicator
|
||||||
|
from ta.volatility import AverageTrueRange, BollingerBands
|
||||||
|
|
||||||
|
close = df["close"].astype(float)
|
||||||
|
high = df["high"].astype(float)
|
||||||
|
low = df["low"].astype(float)
|
||||||
|
vol = df["volume"].astype(float)
|
||||||
|
|
||||||
|
df["r1"] = close.pct_change()
|
||||||
|
df["rsi14"] = RSIIndicator(close=close, window=14, fillna=False).rsi()
|
||||||
|
macd = MACD(close=close, window_slow=26, window_fast=12, window_sign=9, fillna=False)
|
||||||
|
df["macd"] = macd.macd()
|
||||||
|
df["macd_signal"] = macd.macd_signal()
|
||||||
|
df["atr14"] = AverageTrueRange(high=high, low=low, close=close, window=14, fillna=False).average_true_range()
|
||||||
|
bb = BollingerBands(close=close, window=20, window_dev=2, fillna=False)
|
||||||
|
df["bb_pct"] = bb.bollinger_pband()
|
||||||
|
df["sma20"] = SMAIndicator(close=close, window=20, fillna=False).sma_indicator()
|
||||||
|
df["ema12"] = EMAIndicator(close=close, window=12, fillna=False).ema_indicator()
|
||||||
|
vol_mean = vol.rolling(20).mean()
|
||||||
|
vol_std = vol.rolling(20).std().replace(0, np.nan)
|
||||||
|
df["vol_z20"] = (vol - vol_mean) / vol_std
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def _add_targets(df: pd.DataFrame, horizons: tuple[int, ...]) -> pd.DataFrame:
|
||||||
|
close = df["close"].astype(float)
|
||||||
|
for h in horizons:
|
||||||
|
df[f"y_close_h{h}"] = close.shift(-h)
|
||||||
|
df[f"y_ret_h{h}"] = df[f"y_close_h{h}"] / close - 1.0
|
||||||
|
df[f"y_dir_h{h}"] = np.where(
|
||||||
|
df[f"y_ret_h{h}"] > FLAT_BAND, 1,
|
||||||
|
np.where(df[f"y_ret_h{h}"] < -FLAT_BAND, -1, 0),
|
||||||
|
)
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def build_features(
|
||||||
|
code: str,
|
||||||
|
*,
|
||||||
|
lookback_days: int = 365 * 2,
|
||||||
|
end_date: date | None = None,
|
||||||
|
horizons: tuple[int, ...] = HORIZONS_DEFAULT,
|
||||||
|
with_targets: bool = False,
|
||||||
|
) -> FeatureFrame:
|
||||||
|
"""code 1개 종목의 피처 DataFrame 생성.
|
||||||
|
|
||||||
|
inference: with_targets=False 로 호출 → 최신 row 의 피처만 LGBM/Chronos 에 투입.
|
||||||
|
training : with_targets=True 로 호출 → tail H 행은 타깃 NaN → dropna 로 제거.
|
||||||
|
"""
|
||||||
|
end = end_date or date.today()
|
||||||
|
start = end - timedelta(days=lookback_days)
|
||||||
|
|
||||||
|
ohlcv = _load_ohlcv(code, start, end)
|
||||||
|
if ohlcv.empty:
|
||||||
|
return FeatureFrame(code=code, df=ohlcv, target_horizons=horizons)
|
||||||
|
|
||||||
|
df = ohlcv.copy().sort_values("date").reset_index(drop=True)
|
||||||
|
|
||||||
|
df = _add_ta(df)
|
||||||
|
|
||||||
|
trading = _load_trading(code, start, end)
|
||||||
|
if not trading.empty:
|
||||||
|
df = df.merge(trading, on="date", how="left")
|
||||||
|
else:
|
||||||
|
for col in ("foreign_net", "institution_net", "individual_net"):
|
||||||
|
df[col] = np.nan
|
||||||
|
|
||||||
|
macro = _load_macro(start, end)
|
||||||
|
if not macro.empty:
|
||||||
|
df = df.merge(macro, on="date", how="left")
|
||||||
|
for k in ("kospi", "kosdaq", "usdkrw", "us10y"):
|
||||||
|
if k in df.columns:
|
||||||
|
df[f"{k}_r1"] = df[k].pct_change()
|
||||||
|
|
||||||
|
sentiment = _load_sentiment(code, start, end)
|
||||||
|
if not sentiment.empty:
|
||||||
|
df = df.merge(sentiment, on="date", how="left")
|
||||||
|
# 3일 롤링 평균
|
||||||
|
for col in ("mean_score", "weighted_score", "pos_minus_neg", "n_articles"):
|
||||||
|
if col in df.columns:
|
||||||
|
df[f"{col}_3d"] = df[col].rolling(3, min_periods=1).mean()
|
||||||
|
else:
|
||||||
|
for col in ("n_articles", "mean_score", "pos_ratio", "neg_ratio",
|
||||||
|
"weighted_score", "pos_minus_neg"):
|
||||||
|
df[col] = np.nan
|
||||||
|
|
||||||
|
if with_targets:
|
||||||
|
df = _add_targets(df, horizons)
|
||||||
|
|
||||||
|
return FeatureFrame(code=code, df=df, target_horizons=horizons)
|
||||||
|
|
||||||
|
|
||||||
|
def feature_columns(df: pd.DataFrame) -> list[str]:
|
||||||
|
"""LGBM 학습/추론용 피처 컬럼 목록. date / OHLCV / y_* 제외."""
|
||||||
|
drop = {"date", "open", "high", "low", "close", "volume"}
|
||||||
|
cols = [
|
||||||
|
c for c in df.columns
|
||||||
|
if c not in drop and not c.startswith("y_")
|
||||||
|
]
|
||||||
|
return cols
|
||||||
180
backend/app/models/lgbm.py
Normal file
180
backend/app/models/lgbm.py
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
"""LightGBM 회귀 + 분류 모델. 종목 × horizon 별 별도 저장.
|
||||||
|
|
||||||
|
- 회귀: target = y_ret_h{H}. 예측 후 base_close*(1+pred) 로 가격 환산.
|
||||||
|
- 분류: target = y_dir_h{H} ∈ {-1, 0, +1}. 3-class softmax 로 prob_up/flat/down.
|
||||||
|
|
||||||
|
저장 경로: backend/data/models/{code}_h{H}_reg.pkl, _cls.pkl (joblib).
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import joblib
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
from app.models.features import build_features, feature_columns
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
MODEL_DIR = Path(os.environ.get("LGBM_MODEL_DIR", "/app/data/models"))
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LgbmForecast:
|
||||||
|
horizon: int
|
||||||
|
base_close: float
|
||||||
|
predicted_close: float
|
||||||
|
predicted_return: float
|
||||||
|
prob_up: float
|
||||||
|
prob_flat: float
|
||||||
|
prob_down: float
|
||||||
|
|
||||||
|
|
||||||
|
def _model_paths(code: str, horizon: int) -> tuple[Path, Path]:
|
||||||
|
MODEL_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
return (
|
||||||
|
MODEL_DIR / f"{code}_h{horizon}_reg.pkl",
|
||||||
|
MODEL_DIR / f"{code}_h{horizon}_cls.pkl",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _prepare_xy(code: str, horizon: int, lookback_days: int) -> tuple[pd.DataFrame, pd.Series, pd.Series, list[str]]:
|
||||||
|
ff = build_features(
|
||||||
|
code,
|
||||||
|
lookback_days=lookback_days,
|
||||||
|
horizons=(horizon,),
|
||||||
|
with_targets=True,
|
||||||
|
)
|
||||||
|
df = ff.df
|
||||||
|
if df.empty:
|
||||||
|
return df, pd.Series(dtype=float), pd.Series(dtype=int), []
|
||||||
|
y_ret_col = f"y_ret_h{horizon}"
|
||||||
|
y_dir_col = f"y_dir_h{horizon}"
|
||||||
|
# 타깃 NaN (마지막 H 행) 제거.
|
||||||
|
df = df.dropna(subset=[y_ret_col, y_dir_col])
|
||||||
|
feats = feature_columns(df)
|
||||||
|
if not feats:
|
||||||
|
return df, pd.Series(dtype=float), pd.Series(dtype=int), []
|
||||||
|
X = df[feats]
|
||||||
|
# LightGBM 은 NaN 자체 처리 가능.
|
||||||
|
y_ret = df[y_ret_col].astype(float)
|
||||||
|
y_dir = df[y_dir_col].astype(int)
|
||||||
|
return X, y_ret, y_dir, feats
|
||||||
|
|
||||||
|
|
||||||
|
def train_one(code: str, horizon: int, *, lookback_days: int = 365 * 3) -> dict:
|
||||||
|
"""1종목 × 1 horizon 학습. 저장된 모델 파일 경로 + 샘플 수 반환."""
|
||||||
|
import lightgbm as lgb
|
||||||
|
|
||||||
|
X, y_ret, y_dir, feats = _prepare_xy(code, horizon, lookback_days)
|
||||||
|
if X.empty or len(X) < 100:
|
||||||
|
return {"code": code, "horizon": horizon, "status": "skipped_too_few_rows", "n_rows": int(len(X))}
|
||||||
|
|
||||||
|
reg_params = dict(
|
||||||
|
objective="regression",
|
||||||
|
learning_rate=0.05,
|
||||||
|
num_leaves=31,
|
||||||
|
min_data_in_leaf=20,
|
||||||
|
feature_fraction=0.85,
|
||||||
|
bagging_fraction=0.8,
|
||||||
|
bagging_freq=5,
|
||||||
|
verbose=-1,
|
||||||
|
)
|
||||||
|
cls_params = dict(
|
||||||
|
objective="multiclass",
|
||||||
|
num_class=3,
|
||||||
|
learning_rate=0.05,
|
||||||
|
num_leaves=31,
|
||||||
|
min_data_in_leaf=20,
|
||||||
|
feature_fraction=0.85,
|
||||||
|
bagging_fraction=0.8,
|
||||||
|
bagging_freq=5,
|
||||||
|
verbose=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 분류는 -1/0/1 → 0/1/2 인덱스로 매핑.
|
||||||
|
y_dir_idx = (y_dir + 1).astype(int)
|
||||||
|
|
||||||
|
n = len(X)
|
||||||
|
split = int(n * 0.85)
|
||||||
|
X_tr, X_val = X.iloc[:split], X.iloc[split:]
|
||||||
|
yr_tr, yr_val = y_ret.iloc[:split], y_ret.iloc[split:]
|
||||||
|
yc_tr, yc_val = y_dir_idx.iloc[:split], y_dir_idx.iloc[split:]
|
||||||
|
|
||||||
|
reg_train = lgb.Dataset(X_tr, label=yr_tr)
|
||||||
|
reg_valid = lgb.Dataset(X_val, label=yr_val, reference=reg_train)
|
||||||
|
reg_model = lgb.train(
|
||||||
|
reg_params,
|
||||||
|
reg_train,
|
||||||
|
num_boost_round=400,
|
||||||
|
valid_sets=[reg_valid],
|
||||||
|
callbacks=[lgb.early_stopping(stopping_rounds=30, verbose=False)],
|
||||||
|
)
|
||||||
|
|
||||||
|
cls_train = lgb.Dataset(X_tr, label=yc_tr)
|
||||||
|
cls_valid = lgb.Dataset(X_val, label=yc_val, reference=cls_train)
|
||||||
|
cls_model = lgb.train(
|
||||||
|
cls_params,
|
||||||
|
cls_train,
|
||||||
|
num_boost_round=400,
|
||||||
|
valid_sets=[cls_valid],
|
||||||
|
callbacks=[lgb.early_stopping(stopping_rounds=30, verbose=False)],
|
||||||
|
)
|
||||||
|
|
||||||
|
reg_path, cls_path = _model_paths(code, horizon)
|
||||||
|
joblib.dump({"model": reg_model, "features": feats}, reg_path)
|
||||||
|
joblib.dump({"model": cls_model, "features": feats}, cls_path)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"code": code,
|
||||||
|
"horizon": horizon,
|
||||||
|
"status": "ok",
|
||||||
|
"n_rows": int(len(X)),
|
||||||
|
"reg_best_iter": int(reg_model.best_iteration or 0),
|
||||||
|
"cls_best_iter": int(cls_model.best_iteration or 0),
|
||||||
|
"reg_path": str(reg_path),
|
||||||
|
"cls_path": str(cls_path),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def predict_one(code: str, horizon: int, *, lookback_days: int = 400) -> LgbmForecast | None:
|
||||||
|
"""1종목 × 1 horizon 추론. 모델 없으면 None.
|
||||||
|
|
||||||
|
가장 최신 영업일 피처를 사용. base_close 는 그 행의 close.
|
||||||
|
"""
|
||||||
|
reg_path, cls_path = _model_paths(code, horizon)
|
||||||
|
if not reg_path.exists() or not cls_path.exists():
|
||||||
|
return None
|
||||||
|
reg_blob = joblib.load(reg_path)
|
||||||
|
cls_blob = joblib.load(cls_path)
|
||||||
|
feats_reg = reg_blob["features"]
|
||||||
|
feats_cls = cls_blob["features"]
|
||||||
|
reg_model = reg_blob["model"]
|
||||||
|
cls_model = cls_blob["model"]
|
||||||
|
|
||||||
|
ff = build_features(code, lookback_days=lookback_days, horizons=(horizon,), with_targets=False)
|
||||||
|
df = ff.df
|
||||||
|
if df.empty:
|
||||||
|
return None
|
||||||
|
last = df.iloc[[-1]]
|
||||||
|
base_close = float(last["close"].iloc[0])
|
||||||
|
# 피처 정렬 (모델이 학습 당시 본 컬럼 순서대로).
|
||||||
|
X_reg = last.reindex(columns=feats_reg).fillna(value=np.nan)
|
||||||
|
X_cls = last.reindex(columns=feats_cls).fillna(value=np.nan)
|
||||||
|
pred_ret = float(reg_model.predict(X_reg)[0])
|
||||||
|
probs = cls_model.predict(X_cls)[0]
|
||||||
|
# 인덱스 0=-1(down), 1=0(flat), 2=+1(up)
|
||||||
|
prob_down, prob_flat, prob_up = float(probs[0]), float(probs[1]), float(probs[2])
|
||||||
|
return LgbmForecast(
|
||||||
|
horizon=horizon,
|
||||||
|
base_close=base_close,
|
||||||
|
predicted_close=base_close * (1.0 + pred_ret),
|
||||||
|
predicted_return=pred_ret,
|
||||||
|
prob_up=prob_up,
|
||||||
|
prob_flat=prob_flat,
|
||||||
|
prob_down=prob_down,
|
||||||
|
)
|
||||||
75
backend/app/models/weights.py
Normal file
75
backend/app/models/weights.py
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
"""ensemble_weights 테이블 IO. 기본 가중치 (chronos 0.6, lgbm 0.4)."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from app.db.connection import get_engine
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EnsembleWeights:
|
||||||
|
code: str
|
||||||
|
horizon: int
|
||||||
|
w_chronos: float
|
||||||
|
w_lgbm: float
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_W_CHRONOS = 0.6
|
||||||
|
DEFAULT_W_LGBM = 0.4
|
||||||
|
|
||||||
|
|
||||||
|
def load_weights(code: str, horizon: int) -> EnsembleWeights:
|
||||||
|
eng = get_engine()
|
||||||
|
with eng.connect() as conn:
|
||||||
|
row = conn.execute(
|
||||||
|
text(
|
||||||
|
"SELECT w_chronos, w_lgbm FROM ensemble_weights "
|
||||||
|
"WHERE code = :code AND horizon = :h"
|
||||||
|
),
|
||||||
|
{"code": code, "h": horizon},
|
||||||
|
).first()
|
||||||
|
if not row:
|
||||||
|
return EnsembleWeights(code, horizon, DEFAULT_W_CHRONOS, DEFAULT_W_LGBM)
|
||||||
|
return EnsembleWeights(code, horizon, float(row[0]), float(row[1]))
|
||||||
|
|
||||||
|
|
||||||
|
def upsert_weights(
|
||||||
|
code: str,
|
||||||
|
horizon: int,
|
||||||
|
w_chronos: float,
|
||||||
|
w_lgbm: float,
|
||||||
|
*,
|
||||||
|
hit_rate_chronos: float | None = None,
|
||||||
|
hit_rate_lgbm: float | None = None,
|
||||||
|
sample_count: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
eng = get_engine()
|
||||||
|
with eng.begin() as conn:
|
||||||
|
conn.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
INSERT INTO ensemble_weights
|
||||||
|
(code, horizon, w_chronos, w_lgbm, hit_rate_chronos, hit_rate_lgbm, sample_count, updated_at)
|
||||||
|
VALUES
|
||||||
|
(:code, :h, :wc, :wl, :hc, :hl, :n, NOW())
|
||||||
|
ON CONFLICT (code, horizon) DO UPDATE SET
|
||||||
|
w_chronos = EXCLUDED.w_chronos,
|
||||||
|
w_lgbm = EXCLUDED.w_lgbm,
|
||||||
|
hit_rate_chronos = EXCLUDED.hit_rate_chronos,
|
||||||
|
hit_rate_lgbm = EXCLUDED.hit_rate_lgbm,
|
||||||
|
sample_count = EXCLUDED.sample_count,
|
||||||
|
updated_at = NOW()
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{
|
||||||
|
"code": code,
|
||||||
|
"h": horizon,
|
||||||
|
"wc": float(w_chronos),
|
||||||
|
"wl": float(w_lgbm),
|
||||||
|
"hc": hit_rate_chronos,
|
||||||
|
"hl": hit_rate_lgbm,
|
||||||
|
"n": sample_count,
|
||||||
|
},
|
||||||
|
)
|
||||||
0
backend/app/nlp/__init__.py
Normal file
0
backend/app/nlp/__init__.py
Normal file
150
backend/app/nlp/finbert.py
Normal file
150
backend/app/nlp/finbert.py
Normal file
@@ -0,0 +1,150 @@
|
|||||||
|
"""KR-FinBERT 감성 분석 어댑터.
|
||||||
|
|
||||||
|
모델: snunlp/KR-FinBert-SC (3-class: negative / neutral / positive)
|
||||||
|
|
||||||
|
- score : prob(positive) - prob(negative) ∈ [-1, +1]
|
||||||
|
- label : argmax 결과 ('positive' / 'neutral' / 'negative')
|
||||||
|
- embedding : 마지막 hidden state mean pool (768d) — `news.embedding` (VECTOR) 저장용
|
||||||
|
|
||||||
|
디바이스: settings.model_device ('auto' → cuda 가용 시 cuda, 아니면 cpu).
|
||||||
|
인증: settings.huggingface_token (gated 모델은 아니지만 HF rate limit 우회 + 일관성).
|
||||||
|
캐시: HF_HOME=/root/.cache/huggingface (docker-compose 의 `hf_cache` 볼륨).
|
||||||
|
|
||||||
|
lazy singleton — FastAPI 기동 시점에 모델을 로드하지 않고, 첫 score_texts() 호출
|
||||||
|
또는 ping() 호출 시점에 로드.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from app.config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
MODEL_NAME = "snunlp/KR-FinBert-SC"
|
||||||
|
# KR-FinBert-SC 의 id2label : {0: 'negative', 1: 'neutral', 2: 'positive'}
|
||||||
|
_LABELS = ("negative", "neutral", "positive")
|
||||||
|
|
||||||
|
_lock = threading.Lock()
|
||||||
|
_state: dict[str, object] = {
|
||||||
|
"loaded": False,
|
||||||
|
"tokenizer": None,
|
||||||
|
"model": None,
|
||||||
|
"device": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FinbertOutput:
|
||||||
|
label: str
|
||||||
|
score: float # prob_positive - prob_negative ∈ [-1, +1]
|
||||||
|
prob_negative: float
|
||||||
|
prob_neutral: float
|
||||||
|
prob_positive: float
|
||||||
|
embedding: list[float] # 768d mean-pooled last hidden state
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_device() -> str:
|
||||||
|
"""settings.model_device 값에 따라 'cuda'/'cpu' 결정."""
|
||||||
|
import torch # lazy
|
||||||
|
|
||||||
|
pref = (settings.model_device or "auto").lower()
|
||||||
|
if pref == "cuda":
|
||||||
|
return "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
if pref == "cpu":
|
||||||
|
return "cpu"
|
||||||
|
# 'auto'
|
||||||
|
return "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
|
||||||
|
def _load() -> None:
|
||||||
|
global _state
|
||||||
|
with _lock:
|
||||||
|
if _state["loaded"]:
|
||||||
|
return
|
||||||
|
import torch
|
||||||
|
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||||||
|
|
||||||
|
token = settings.huggingface_token or None
|
||||||
|
if token:
|
||||||
|
# transformers/datasets 모두 이 env 를 인식.
|
||||||
|
os.environ.setdefault("HUGGINGFACE_HUB_TOKEN", token)
|
||||||
|
os.environ.setdefault("HF_TOKEN", token)
|
||||||
|
|
||||||
|
device = _resolve_device()
|
||||||
|
logger.info("loading %s on %s", MODEL_NAME, device)
|
||||||
|
tok = AutoTokenizer.from_pretrained(MODEL_NAME, token=token)
|
||||||
|
mdl = AutoModelForSequenceClassification.from_pretrained(
|
||||||
|
MODEL_NAME,
|
||||||
|
token=token,
|
||||||
|
output_hidden_states=True,
|
||||||
|
)
|
||||||
|
mdl.eval()
|
||||||
|
mdl.to(device)
|
||||||
|
_state.update({"loaded": True, "tokenizer": tok, "model": mdl, "device": device})
|
||||||
|
logger.info("KR-FinBERT loaded (device=%s)", device)
|
||||||
|
|
||||||
|
|
||||||
|
def score_texts(
|
||||||
|
texts: list[str],
|
||||||
|
*,
|
||||||
|
batch_size: int = 16,
|
||||||
|
max_length: int = 256,
|
||||||
|
) -> list[FinbertOutput]:
|
||||||
|
"""주어진 텍스트 리스트에 대해 감성 점수 + 라벨 + 768d embedding 반환.
|
||||||
|
|
||||||
|
빈 문자열은 placeholder('_')로 치환해서 라벨은 neutral 에 가깝게 나오게 함.
|
||||||
|
"""
|
||||||
|
if not texts:
|
||||||
|
return []
|
||||||
|
_load()
|
||||||
|
import torch
|
||||||
|
|
||||||
|
tok = _state["tokenizer"]
|
||||||
|
mdl = _state["model"]
|
||||||
|
device = _state["device"]
|
||||||
|
|
||||||
|
results: list[FinbertOutput] = []
|
||||||
|
with torch.no_grad():
|
||||||
|
for i in range(0, len(texts), batch_size):
|
||||||
|
chunk = [(t or "").strip() or "_" for t in texts[i : i + batch_size]]
|
||||||
|
enc = tok(
|
||||||
|
chunk,
|
||||||
|
padding=True,
|
||||||
|
truncation=True,
|
||||||
|
max_length=max_length,
|
||||||
|
return_tensors="pt",
|
||||||
|
).to(device)
|
||||||
|
out = mdl(**enc)
|
||||||
|
probs = torch.softmax(out.logits, dim=-1).cpu()
|
||||||
|
last_hidden = out.hidden_states[-1] # (B, T, H)
|
||||||
|
mask = enc["attention_mask"].unsqueeze(-1).float()
|
||||||
|
pooled = (last_hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
|
||||||
|
pooled = pooled.cpu().tolist()
|
||||||
|
|
||||||
|
for row, vec in zip(probs.tolist(), pooled):
|
||||||
|
p_neg, p_neu, p_pos = row[0], row[1], row[2]
|
||||||
|
label_idx = int(max(range(3), key=lambda k: row[k]))
|
||||||
|
results.append(
|
||||||
|
FinbertOutput(
|
||||||
|
label=_LABELS[label_idx],
|
||||||
|
score=float(p_pos - p_neg),
|
||||||
|
prob_negative=float(p_neg),
|
||||||
|
prob_neutral=float(p_neu),
|
||||||
|
prob_positive=float(p_pos),
|
||||||
|
embedding=[float(x) for x in vec],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def ping() -> dict[str, object]:
|
||||||
|
"""모델 로드 가능 여부 빠르게 확인. 한 번 로드되면 캐시됨."""
|
||||||
|
try:
|
||||||
|
_load()
|
||||||
|
return {"status": "ok", "model": MODEL_NAME, "device": _state["device"]}
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
return {"status": "failed", "model": MODEL_NAME, "error": str(exc)}
|
||||||
96
backend/app/nlp/score_news.py
Normal file
96
backend/app/nlp/score_news.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
"""DB news 테이블에서 sentiment_score IS NULL 인 행을 배치로 스코어링.
|
||||||
|
|
||||||
|
refresh_one / daily_batch 에서 뉴스 upsert 직후 호출. 증분만 처리.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from app.db.connection import get_engine
|
||||||
|
from app.nlp.finbert import score_texts
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ScoreResult:
|
||||||
|
fetched: int
|
||||||
|
scored: int
|
||||||
|
failed: int
|
||||||
|
error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _vector_literal(vec: list[float]) -> str:
|
||||||
|
"""pgvector 텍스트 리터럴: '[v1,v2,...]' 형식. (:e)::vector 로 캐스팅."""
|
||||||
|
return "[" + ",".join(f"{x:.6f}" for x in vec) + "]"
|
||||||
|
|
||||||
|
|
||||||
|
def score_pending_news(
|
||||||
|
*,
|
||||||
|
batch_size: int = 32,
|
||||||
|
limit: int | None = 500,
|
||||||
|
code: str | None = None,
|
||||||
|
) -> ScoreResult:
|
||||||
|
"""sentiment_score IS NULL 인 news 행에 대해 finbert score + label + embedding 채움.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch_size: finbert inference 배치 크기.
|
||||||
|
limit: 한 번에 처리할 최대 행 수. None 이면 무제한.
|
||||||
|
daily_batch 가 너무 무겁지 않도록 기본 500.
|
||||||
|
code: 특정 종목만 (None 이면 전체 무점수 행).
|
||||||
|
"""
|
||||||
|
eng = get_engine()
|
||||||
|
where = "sentiment_score IS NULL"
|
||||||
|
params: dict[str, Any] = {}
|
||||||
|
if code:
|
||||||
|
where += " AND code = :code"
|
||||||
|
params["code"] = code
|
||||||
|
|
||||||
|
sql_select = (
|
||||||
|
"SELECT id, COALESCE(title, '') || ' ' || COALESCE(body, '') AS txt "
|
||||||
|
f"FROM news WHERE {where} ORDER BY id"
|
||||||
|
)
|
||||||
|
if limit is not None:
|
||||||
|
sql_select += f" LIMIT {int(limit)}"
|
||||||
|
|
||||||
|
with eng.connect() as conn:
|
||||||
|
rows = conn.execute(text(sql_select), params).all()
|
||||||
|
if not rows:
|
||||||
|
return ScoreResult(fetched=0, scored=0, failed=0)
|
||||||
|
|
||||||
|
ids = [r[0] for r in rows]
|
||||||
|
texts_in = [r[1] for r in rows]
|
||||||
|
|
||||||
|
try:
|
||||||
|
outputs = score_texts(texts_in, batch_size=batch_size)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
logger.exception("score_texts failed")
|
||||||
|
return ScoreResult(fetched=len(rows), scored=0, failed=len(rows), error=str(exc))
|
||||||
|
|
||||||
|
update_sql = text(
|
||||||
|
"UPDATE news SET sentiment_score = :s, sentiment_label = :l, "
|
||||||
|
"embedding = (:e)::vector WHERE id = :id"
|
||||||
|
)
|
||||||
|
scored = 0
|
||||||
|
failed = 0
|
||||||
|
with eng.begin() as conn:
|
||||||
|
for nid, out in zip(ids, outputs):
|
||||||
|
try:
|
||||||
|
conn.execute(
|
||||||
|
update_sql,
|
||||||
|
{
|
||||||
|
"id": nid,
|
||||||
|
"s": out.score,
|
||||||
|
"l": out.label,
|
||||||
|
"e": _vector_literal(out.embedding),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
scored += 1
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
logger.warning("update news id=%s failed: %s", nid, exc)
|
||||||
|
failed += 1
|
||||||
|
return ScoreResult(fetched=len(rows), scored=scored, failed=failed)
|
||||||
@@ -11,6 +11,7 @@ import time
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.fetch import macro as macro_mod
|
from app.fetch import macro as macro_mod
|
||||||
|
from app.nlp.score_news import score_pending_news
|
||||||
from app.pipelines.refresh_one import refresh_code
|
from app.pipelines.refresh_one import refresh_code
|
||||||
from app.seed.seed_tickers import SEED_TICKERS
|
from app.seed.seed_tickers import SEED_TICKERS
|
||||||
|
|
||||||
@@ -32,11 +33,26 @@ def run_daily_batch() -> dict[str, Any]:
|
|||||||
for m in macros
|
for m in macros
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# 시드 종목 refresh 내에서 종목당 200건만 스코어함. 잔여(여러 소스 합쳐
|
||||||
|
# 200건 초과 또는 코드 매핑 안된 google_rss 등)는 여기서 한 번에 mop-up.
|
||||||
|
try:
|
||||||
|
mop = score_pending_news(limit=2000)
|
||||||
|
sentiment_summary: dict[str, Any] = {
|
||||||
|
"status": "ok" if mop.error is None else "failed",
|
||||||
|
"fetched": mop.fetched,
|
||||||
|
"scored": mop.scored,
|
||||||
|
"failed": mop.failed,
|
||||||
|
"error": mop.error,
|
||||||
|
}
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
sentiment_summary = {"status": "failed", "error": str(exc)}
|
||||||
|
|
||||||
elapsed = time.time() - start_ts
|
elapsed = time.time() - start_ts
|
||||||
return {
|
return {
|
||||||
"duration_seconds": round(elapsed, 2),
|
"duration_seconds": round(elapsed, 2),
|
||||||
"tickers": reports,
|
"tickers": reports,
|
||||||
"macro": macro_summary,
|
"macro": macro_summary,
|
||||||
|
"sentiment_mop": sentiment_summary,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
173
backend/app/pipelines/match_outcomes.py
Normal file
173
backend/app/pipelines/match_outcomes.py
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
"""prediction_outcomes 매칭 배치.
|
||||||
|
|
||||||
|
평일 16:30 KST 에 실행. 다음 거래일 장 종료 후 (KRX 정규장 마감 15:30) 의
|
||||||
|
확정 종가가 16:00~16:30 사이 pykrx 로 들어온 뒤, 매칭 미해결 예측을 실제
|
||||||
|
종가와 매칭한다.
|
||||||
|
|
||||||
|
이월/공휴일 정책:
|
||||||
|
target_date 가 calendar date 라서 비거래일이면 ohlcv_daily 에 행이 없다.
|
||||||
|
그래서 `target_date <= today` 인 미해결 행을 전부 후보로 잡고, 각 행마다
|
||||||
|
`target_date <= ohlcv_daily.date <= today` 범위의 최초 거래일 종가로
|
||||||
|
매칭한다 (=다음 거래일로 자동 이월).
|
||||||
|
|
||||||
|
shadow prediction 도 같은 방식으로 매칭한다 (user_triggered 필터 없음).
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import date, timedelta
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from app.db.connection import get_engine
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# direction_hit 판정 시 ±0.3% 이내는 flat. (features 의 FLAT_BAND 와 동일)
|
||||||
|
FLAT_BAND = 0.003
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MatchSummary:
|
||||||
|
today: str
|
||||||
|
candidates: int
|
||||||
|
matched: int
|
||||||
|
skipped_no_actual: int
|
||||||
|
already_resolved: int
|
||||||
|
|
||||||
|
|
||||||
|
def _direction_label(ret: float) -> str:
|
||||||
|
if ret > FLAT_BAND:
|
||||||
|
return "up"
|
||||||
|
if ret < -FLAT_BAND:
|
||||||
|
return "down"
|
||||||
|
return "flat"
|
||||||
|
|
||||||
|
|
||||||
|
def match_up_to(today: date) -> MatchSummary:
|
||||||
|
"""target_date <= today 인 모든 미해결 예측을 매칭.
|
||||||
|
|
||||||
|
각 행마다 ohlcv_daily 에서 target_date 이상, today 이하 범위의 최초
|
||||||
|
거래일 종가를 actual_close 로 사용 — 공휴일/주말 이월 자연 처리.
|
||||||
|
"""
|
||||||
|
eng = get_engine()
|
||||||
|
with eng.begin() as conn:
|
||||||
|
candidate_rows = conn.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
SELECT p.id, p.code, p.base_date, p.target_date, p.horizon,
|
||||||
|
p.point_forecast, p.direction, p.model
|
||||||
|
FROM predictions p
|
||||||
|
LEFT JOIN prediction_outcomes po ON po.prediction_id = p.id
|
||||||
|
WHERE p.target_date <= :today
|
||||||
|
AND po.prediction_id IS NULL
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{"today": today},
|
||||||
|
).all()
|
||||||
|
candidates = len(candidate_rows)
|
||||||
|
if not candidates:
|
||||||
|
return MatchSummary(str(today), 0, 0, 0, 0)
|
||||||
|
|
||||||
|
matched = 0
|
||||||
|
skipped = 0
|
||||||
|
already = 0
|
||||||
|
for pid, code, base_date, target_date, horizon, point_forecast, pred_dir, model in candidate_rows:
|
||||||
|
# 첫 거래일 종가 (target_date <= date <= today)
|
||||||
|
actual_row = conn.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
SELECT date, close FROM ohlcv_daily
|
||||||
|
WHERE code = :c AND date >= :td AND date <= :today
|
||||||
|
ORDER BY date ASC
|
||||||
|
LIMIT 1
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{"c": code, "td": target_date, "today": today},
|
||||||
|
).first()
|
||||||
|
if not actual_row or actual_row[1] is None:
|
||||||
|
skipped += 1
|
||||||
|
continue
|
||||||
|
actual_date = actual_row[0]
|
||||||
|
actual = float(actual_row[1])
|
||||||
|
|
||||||
|
base_close_row = conn.execute(
|
||||||
|
text("SELECT close FROM ohlcv_daily WHERE code = :c AND date = :d"),
|
||||||
|
{"c": code, "d": base_date},
|
||||||
|
).first()
|
||||||
|
if not base_close_row or base_close_row[0] is None:
|
||||||
|
skipped += 1
|
||||||
|
continue
|
||||||
|
base_close = float(base_close_row[0])
|
||||||
|
|
||||||
|
actual_ret = actual / base_close - 1.0
|
||||||
|
actual_dir = _direction_label(actual_ret)
|
||||||
|
dir_hit = (pred_dir == actual_dir)
|
||||||
|
abs_err = abs(float(point_forecast) - actual) if point_forecast is not None else None
|
||||||
|
|
||||||
|
try:
|
||||||
|
conn.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
INSERT INTO prediction_outcomes
|
||||||
|
(prediction_id, code, target_date, horizon, model,
|
||||||
|
predicted_close, actual_close, actual_return, direction_hit, abs_error)
|
||||||
|
VALUES
|
||||||
|
(:pid, :code, :d, :h, :m, :pc, :ac, :ar, :dh, :ae)
|
||||||
|
ON CONFLICT (prediction_id) DO NOTHING
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{
|
||||||
|
"pid": pid,
|
||||||
|
"code": code,
|
||||||
|
# 실제 매칭된 거래일 (이월된 경우 target_date 와 다를 수 있음)
|
||||||
|
"d": actual_date,
|
||||||
|
"h": horizon,
|
||||||
|
"m": model,
|
||||||
|
"pc": float(point_forecast) if point_forecast is not None else None,
|
||||||
|
"ac": actual,
|
||||||
|
"ar": float(actual_ret),
|
||||||
|
"dh": bool(dir_hit),
|
||||||
|
"ae": float(abs_err) if abs_err is not None else None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
matched += 1
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
logger.warning("match insert failed pid=%s: %s", pid, exc)
|
||||||
|
already += 1
|
||||||
|
|
||||||
|
return MatchSummary(
|
||||||
|
today=str(today),
|
||||||
|
candidates=candidates,
|
||||||
|
matched=matched,
|
||||||
|
skipped_no_actual=skipped,
|
||||||
|
already_resolved=already,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# 하위 호환 alias — 이전 시그니처를 쓰던 호출자 (예: 단일 날짜 매칭 테스트)
|
||||||
|
def match_for_date(d: date) -> MatchSummary:
|
||||||
|
"""legacy: target_date == d 만 매칭하던 동작 → 이제 target_date <= d 전체 처리."""
|
||||||
|
return match_up_to(d)
|
||||||
|
|
||||||
|
|
||||||
|
def match_today() -> dict[str, Any]:
|
||||||
|
"""평일 16:30 KST 호출용. target_date <= today (KST) 인 미해결 행 일괄 매칭."""
|
||||||
|
from datetime import datetime, timezone, timedelta as td
|
||||||
|
|
||||||
|
kst = timezone(td(hours=9))
|
||||||
|
today = datetime.now(kst).date()
|
||||||
|
summary = match_up_to(today)
|
||||||
|
return {
|
||||||
|
"today": str(today),
|
||||||
|
"summary": summary.__dict__,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
||||||
|
out = match_today()
|
||||||
|
print(json.dumps(out, ensure_ascii=False, indent=2, default=str))
|
||||||
267
backend/app/pipelines/predict_one.py
Normal file
267
backend/app/pipelines/predict_one.py
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
"""On-demand 예측 + DB 적재.
|
||||||
|
|
||||||
|
POST /api/predict/{code} 에서 호출. 사용자가 "예상차트 보기" 누른 시점.
|
||||||
|
- ensemble.predict() 로 horizons (1,3,5) 결과 계산
|
||||||
|
- base_date = 마지막 ohlcv_daily.date, target_date = base_date + horizon 영업일
|
||||||
|
(주말만 스킵하는 단순 카운트. 공휴일은 match_outcomes 가 "target_date 이후
|
||||||
|
최초 거래일 종가"로 자동 이월하여 보정.)
|
||||||
|
|
||||||
|
세 종류의 행을 함께 저장한다:
|
||||||
|
- model='ensemble' : 사용자에게 보여주는 최종 예측. user_triggered 플래그 따라감.
|
||||||
|
- model='chronos' : Chronos 단독 (shadow). user_triggered=FALSE 로 항상 적재.
|
||||||
|
- model='lgbm' : LGBM 단독 (shadow). user_triggered=FALSE 로 항상 적재.
|
||||||
|
|
||||||
|
shadow 행은 retrain_weekly 가 모델별 hit_rate 를 비교해 ensemble_weights 를
|
||||||
|
자동 보정하는 입력이 된다.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from dataclasses import asdict
|
||||||
|
from datetime import date, datetime, timedelta, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from app.db.connection import get_engine
|
||||||
|
from app.models.ensemble import EnsemblePrediction, predict as ensemble_predict
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
KST = timezone(timedelta(hours=9))
|
||||||
|
|
||||||
|
# ±0.3% flat band — features.FLAT_BAND, match_outcomes.FLAT_BAND 와 동일.
|
||||||
|
FLAT_BAND = 0.003
|
||||||
|
|
||||||
|
|
||||||
|
def _next_trading_target(base_date: date, horizon: int) -> date:
|
||||||
|
"""base_date + horizon 거래일 (주말만 스킵, 공휴일은 무시 — 매칭잡이 자연 보정)."""
|
||||||
|
d = base_date
|
||||||
|
added = 0
|
||||||
|
while added < horizon:
|
||||||
|
d = d + timedelta(days=1)
|
||||||
|
if d.weekday() < 5: # 0..4 = Mon..Fri
|
||||||
|
added += 1
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
def _last_trading_date(code: str) -> date | None:
|
||||||
|
eng = get_engine()
|
||||||
|
with eng.connect() as conn:
|
||||||
|
row = conn.execute(
|
||||||
|
text("SELECT MAX(date) FROM ohlcv_daily WHERE code = :c"),
|
||||||
|
{"c": code},
|
||||||
|
).first()
|
||||||
|
return row[0] if row and row[0] else None
|
||||||
|
|
||||||
|
|
||||||
|
def _direction_label(ret: float) -> str:
|
||||||
|
if ret > FLAT_BAND:
|
||||||
|
return "up"
|
||||||
|
if ret < -FLAT_BAND:
|
||||||
|
return "down"
|
||||||
|
return "flat"
|
||||||
|
|
||||||
|
|
||||||
|
_INSERT_PREDICTION_SQL = text(
|
||||||
|
"""
|
||||||
|
INSERT INTO predictions
|
||||||
|
(code, predicted_at, base_date, target_date, horizon, model,
|
||||||
|
direction, prob_up, prob_flat, prob_down, expected_return,
|
||||||
|
point_forecast, ci_low, ci_high, features_snapshot, user_triggered)
|
||||||
|
VALUES
|
||||||
|
(:code, :predicted_at, :base_date, :target_date, :horizon, :model,
|
||||||
|
:direction, :p_up, :p_fl, :p_dn, :exp_ret,
|
||||||
|
:point, :lo, :hi, CAST(:feats AS JSONB), :ut)
|
||||||
|
ON CONFLICT (code, base_date, target_date, horizon, model)
|
||||||
|
DO UPDATE SET
|
||||||
|
predicted_at = EXCLUDED.predicted_at,
|
||||||
|
direction = EXCLUDED.direction,
|
||||||
|
prob_up = EXCLUDED.prob_up,
|
||||||
|
prob_flat = EXCLUDED.prob_flat,
|
||||||
|
prob_down = EXCLUDED.prob_down,
|
||||||
|
expected_return = EXCLUDED.expected_return,
|
||||||
|
point_forecast = EXCLUDED.point_forecast,
|
||||||
|
ci_low = EXCLUDED.ci_low,
|
||||||
|
ci_high = EXCLUDED.ci_high,
|
||||||
|
features_snapshot = EXCLUDED.features_snapshot,
|
||||||
|
user_triggered = predictions.user_triggered OR EXCLUDED.user_triggered
|
||||||
|
RETURNING id
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _insert_prediction(conn, *, model: str, code: str, predicted_at: datetime,
|
||||||
|
base_date: date, target_date: date, horizon: int,
|
||||||
|
direction: str, p_up: float, p_fl: float, p_dn: float,
|
||||||
|
expected_return: float, point: float | None,
|
||||||
|
lo: float | None, hi: float | None,
|
||||||
|
features_snap: dict, user_triggered: bool) -> int | None:
|
||||||
|
row = conn.execute(
|
||||||
|
_INSERT_PREDICTION_SQL,
|
||||||
|
{
|
||||||
|
"code": code,
|
||||||
|
"predicted_at": predicted_at,
|
||||||
|
"base_date": base_date,
|
||||||
|
"target_date": target_date,
|
||||||
|
"horizon": horizon,
|
||||||
|
"model": model,
|
||||||
|
"direction": direction,
|
||||||
|
"p_up": p_up,
|
||||||
|
"p_fl": p_fl,
|
||||||
|
"p_dn": p_dn,
|
||||||
|
"exp_ret": expected_return,
|
||||||
|
"point": point,
|
||||||
|
"lo": lo,
|
||||||
|
"hi": hi,
|
||||||
|
"feats": json.dumps(features_snap),
|
||||||
|
"ut": user_triggered,
|
||||||
|
},
|
||||||
|
).first()
|
||||||
|
return int(row[0]) if row else None
|
||||||
|
|
||||||
|
|
||||||
|
def predict_and_store(
|
||||||
|
code: str,
|
||||||
|
*,
|
||||||
|
horizons: tuple[int, ...] = (1, 3, 5),
|
||||||
|
user_triggered: bool = True,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""앙상블 예측 실행 + predictions 테이블 적재.
|
||||||
|
|
||||||
|
적재 행:
|
||||||
|
- 'ensemble' (user_triggered 인자 반영)
|
||||||
|
- 'chronos' (shadow, user_triggered=FALSE) — Chronos 가 성공했을 때만
|
||||||
|
- 'lgbm' (shadow, user_triggered=FALSE) — LGBM 이 성공한 horizon 만
|
||||||
|
"""
|
||||||
|
base_date = _last_trading_date(code)
|
||||||
|
if base_date is None:
|
||||||
|
raise RuntimeError(f"no ohlcv_daily for {code}; refresh first")
|
||||||
|
|
||||||
|
pred: EnsemblePrediction = ensemble_predict(code, horizons=horizons)
|
||||||
|
now = datetime.now(KST)
|
||||||
|
base_close = pred.base_close
|
||||||
|
|
||||||
|
eng = get_engine()
|
||||||
|
saved_ids: dict[str, list[int]] = {"ensemble": [], "chronos": [], "lgbm": []}
|
||||||
|
with eng.begin() as conn:
|
||||||
|
for step in pred.steps:
|
||||||
|
target_date = _next_trading_target(base_date, step.horizon)
|
||||||
|
|
||||||
|
# --- ensemble row ---
|
||||||
|
features_snap = {
|
||||||
|
"base_close": base_close,
|
||||||
|
"sources_used": pred.sources_used,
|
||||||
|
"direction": step.direction,
|
||||||
|
}
|
||||||
|
pid = _insert_prediction(
|
||||||
|
conn,
|
||||||
|
model="ensemble",
|
||||||
|
code=code,
|
||||||
|
predicted_at=now,
|
||||||
|
base_date=base_date,
|
||||||
|
target_date=target_date,
|
||||||
|
horizon=step.horizon,
|
||||||
|
direction=step.direction,
|
||||||
|
p_up=step.prob_up,
|
||||||
|
p_fl=step.prob_flat,
|
||||||
|
p_dn=step.prob_down,
|
||||||
|
expected_return=step.expected_return,
|
||||||
|
point=step.point_close,
|
||||||
|
lo=step.ci_low,
|
||||||
|
hi=step.ci_high,
|
||||||
|
features_snap=features_snap,
|
||||||
|
user_triggered=user_triggered,
|
||||||
|
)
|
||||||
|
if pid is not None:
|
||||||
|
saved_ids["ensemble"].append(pid)
|
||||||
|
|
||||||
|
# --- chronos shadow row ---
|
||||||
|
cf = pred.chronos_raw
|
||||||
|
if cf is not None:
|
||||||
|
c_med = float(cf.median[step.horizon - 1])
|
||||||
|
c_q10 = float(cf.q10[step.horizon - 1])
|
||||||
|
c_q90 = float(cf.q90[step.horizon - 1])
|
||||||
|
# direction prob: chronos sample 분포
|
||||||
|
try:
|
||||||
|
import numpy as np
|
||||||
|
arr = np.array(cf.samples)[:, step.horizon - 1]
|
||||||
|
ret = arr / base_close - 1.0
|
||||||
|
cp_up = float((ret > FLAT_BAND).mean())
|
||||||
|
cp_dn = float((ret < -FLAT_BAND).mean())
|
||||||
|
cp_fl = max(0.0, 1.0 - cp_up - cp_dn)
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
cp_up = cp_fl = cp_dn = 1.0 / 3.0
|
||||||
|
exp_ret_c = c_med / base_close - 1.0
|
||||||
|
c_dir = _direction_label(exp_ret_c)
|
||||||
|
pid_c = _insert_prediction(
|
||||||
|
conn,
|
||||||
|
model="chronos",
|
||||||
|
code=code,
|
||||||
|
predicted_at=now,
|
||||||
|
base_date=base_date,
|
||||||
|
target_date=target_date,
|
||||||
|
horizon=step.horizon,
|
||||||
|
direction=c_dir,
|
||||||
|
p_up=cp_up,
|
||||||
|
p_fl=cp_fl,
|
||||||
|
p_dn=cp_dn,
|
||||||
|
expected_return=exp_ret_c,
|
||||||
|
point=c_med,
|
||||||
|
lo=c_q10,
|
||||||
|
hi=c_q90,
|
||||||
|
features_snap={"shadow": True, "base_close": base_close},
|
||||||
|
user_triggered=False,
|
||||||
|
)
|
||||||
|
if pid_c is not None:
|
||||||
|
saved_ids["chronos"].append(pid_c)
|
||||||
|
|
||||||
|
# --- lgbm shadow row ---
|
||||||
|
lf = pred.lgbm_raw.get(step.horizon)
|
||||||
|
if lf is not None:
|
||||||
|
l_close = float(lf.predicted_close)
|
||||||
|
exp_ret_l = l_close / base_close - 1.0
|
||||||
|
l_dir = _direction_label(exp_ret_l)
|
||||||
|
pid_l = _insert_prediction(
|
||||||
|
conn,
|
||||||
|
model="lgbm",
|
||||||
|
code=code,
|
||||||
|
predicted_at=now,
|
||||||
|
base_date=base_date,
|
||||||
|
target_date=target_date,
|
||||||
|
horizon=step.horizon,
|
||||||
|
direction=l_dir,
|
||||||
|
p_up=float(lf.prob_up),
|
||||||
|
p_fl=float(lf.prob_flat),
|
||||||
|
p_dn=float(lf.prob_down),
|
||||||
|
expected_return=exp_ret_l,
|
||||||
|
point=l_close,
|
||||||
|
lo=l_close * 0.97,
|
||||||
|
hi=l_close * 1.03,
|
||||||
|
features_snap={"shadow": True, "base_close": base_close},
|
||||||
|
user_triggered=False,
|
||||||
|
)
|
||||||
|
if pid_l is not None:
|
||||||
|
saved_ids["lgbm"].append(pid_l)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"code": code,
|
||||||
|
"base_date": str(base_date),
|
||||||
|
"base_close": pred.base_close,
|
||||||
|
"sources_used": pred.sources_used,
|
||||||
|
"steps": [
|
||||||
|
{
|
||||||
|
**asdict(s),
|
||||||
|
"target_date": str(_next_trading_target(base_date, s.horizon)),
|
||||||
|
}
|
||||||
|
for s in pred.steps
|
||||||
|
],
|
||||||
|
# UI 는 ensemble id 만 본다. shadow 는 디버깅/검증용으로 별도 키.
|
||||||
|
"saved_prediction_ids": saved_ids["ensemble"],
|
||||||
|
"saved_shadow_ids": {
|
||||||
|
"chronos": saved_ids["chronos"],
|
||||||
|
"lgbm": saved_ids["lgbm"],
|
||||||
|
},
|
||||||
|
"user_triggered": user_triggered,
|
||||||
|
}
|
||||||
@@ -36,6 +36,7 @@ class RefreshReport:
|
|||||||
dart: SourceStatus
|
dart: SourceStatus
|
||||||
naver_news: SourceStatus
|
naver_news: SourceStatus
|
||||||
google_rss: SourceStatus
|
google_rss: SourceStatus
|
||||||
|
finbert: SourceStatus
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
out: dict[str, Any] = {"code": self.code}
|
out: dict[str, Any] = {"code": self.code}
|
||||||
@@ -46,6 +47,7 @@ class RefreshReport:
|
|||||||
"dart",
|
"dart",
|
||||||
"naver_news",
|
"naver_news",
|
||||||
"google_rss",
|
"google_rss",
|
||||||
|
"finbert",
|
||||||
):
|
):
|
||||||
v: SourceStatus = getattr(self, f)
|
v: SourceStatus = getattr(self, f)
|
||||||
out[f] = asdict(v)
|
out[f] = asdict(v)
|
||||||
@@ -132,6 +134,25 @@ def _google_rss(code: str, name: str) -> SourceStatus:
|
|||||||
return SourceStatus(status="failed", error=str(exc))
|
return SourceStatus(status="failed", error=str(exc))
|
||||||
|
|
||||||
|
|
||||||
|
def _finbert(code: str) -> SourceStatus:
|
||||||
|
"""방금 upsert 된 뉴스 중 sentiment_score 가 비어있는 행을 KR-FinBERT 로 스코어."""
|
||||||
|
try:
|
||||||
|
from app.nlp.score_news import score_pending_news
|
||||||
|
|
||||||
|
# 한 종목에 대해 신규 뉴스가 매우 많아도 200건으로 컷.
|
||||||
|
# daily_batch 끝에서 잔여분을 별도로 mop-up 한다.
|
||||||
|
res = score_pending_news(code=code, limit=200)
|
||||||
|
return SourceStatus(
|
||||||
|
status="ok" if res.error is None else "failed",
|
||||||
|
inserted=res.scored,
|
||||||
|
skipped=res.failed,
|
||||||
|
extra={"fetched": res.fetched},
|
||||||
|
error=res.error,
|
||||||
|
)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
return SourceStatus(status="failed", error=str(exc))
|
||||||
|
|
||||||
|
|
||||||
def refresh_code(code: str, name: str, *, lookback_days: int = 7) -> RefreshReport:
|
def refresh_code(code: str, name: str, *, lookback_days: int = 7) -> RefreshReport:
|
||||||
"""단기 갱신 (daily_batch 용). 최근 lookback_days 만 가져온다."""
|
"""단기 갱신 (daily_batch 용). 최근 lookback_days 만 가져온다."""
|
||||||
end = date.today()
|
end = date.today()
|
||||||
@@ -144,4 +165,5 @@ def refresh_code(code: str, name: str, *, lookback_days: int = 7) -> RefreshRepo
|
|||||||
dart=_dart(code, start, end),
|
dart=_dart(code, start, end),
|
||||||
naver_news=_naver_news(code),
|
naver_news=_naver_news(code),
|
||||||
google_rss=_google_rss(code, name),
|
google_rss=_google_rss(code, name),
|
||||||
|
finbert=_finbert(code),
|
||||||
)
|
)
|
||||||
|
|||||||
207
backend/app/pipelines/retrain_weekly.py
Normal file
207
backend/app/pipelines/retrain_weekly.py
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
"""주간 재학습 + 앙상블 가중치 보정.
|
||||||
|
|
||||||
|
일요일 02:00 KST 실행:
|
||||||
|
1. 시드 10종목 × horizon (1,3,5) 별로 LGBM 학습 (train_one).
|
||||||
|
2. 최근 30일 prediction_outcomes 의 (code, model, horizon) 별 hit_rate / mae
|
||||||
|
산출, model_performance 적재.
|
||||||
|
3. shadow 행 (model='chronos' / 'lgbm') 의 hit_rate 를 비교해서
|
||||||
|
ensemble_weights 자동 보정.
|
||||||
|
|
||||||
|
가중치 공식:
|
||||||
|
w_c = clamp(0.1, hr_c / (hr_c + hr_l), 0.9)
|
||||||
|
w_l = 1 - w_c
|
||||||
|
단 sample_count_c < MIN_SAMPLE 또는 sample_count_l < MIN_SAMPLE 이면
|
||||||
|
기본값 유지 (DB row 미생성). hr_c + hr_l == 0 (둘 다 0%) 이면 50:50.
|
||||||
|
|
||||||
|
predict_one 이 매 호출마다 chronos/lgbm shadow 행을 함께 적재하고
|
||||||
|
match_outcomes 가 user_triggered 무관하게 매칭하므로, hit_rate 데이터는
|
||||||
|
사용자가 예측을 한 번이라도 본 종목에 대해 자연스럽게 쌓인다.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from datetime import date, timedelta
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
from app.db.connection import get_engine
|
||||||
|
from app.models.lgbm import train_one
|
||||||
|
from app.models.weights import upsert_weights
|
||||||
|
from app.seed.seed_tickers import SEED_TICKERS
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
HORIZONS = (1, 3, 5)
|
||||||
|
WINDOW_DAYS = 30
|
||||||
|
MIN_SAMPLE = 10 # 모델당 최소 매칭 표본
|
||||||
|
W_CHRONOS_MIN = 0.1
|
||||||
|
W_CHRONOS_MAX = 0.9
|
||||||
|
|
||||||
|
|
||||||
|
def retrain_all() -> list[dict[str, Any]]:
|
||||||
|
"""시드 10종목 × horizons 학습."""
|
||||||
|
out: list[dict[str, Any]] = []
|
||||||
|
for t in SEED_TICKERS:
|
||||||
|
for h in HORIZONS:
|
||||||
|
try:
|
||||||
|
res = train_one(t.code, h)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
logger.exception("train_one failed for %s h=%d", t.code, h)
|
||||||
|
res = {"code": t.code, "horizon": h, "status": "failed", "error": str(exc)}
|
||||||
|
out.append(res)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def record_performance(as_of: date) -> list[dict[str, Any]]:
|
||||||
|
"""최근 WINDOW_DAYS 의 prediction_outcomes 로 (code, model, horizon) 별
|
||||||
|
hit_rate / mae 산출, model_performance 에 upsert."""
|
||||||
|
eng = get_engine()
|
||||||
|
start = as_of - timedelta(days=WINDOW_DAYS)
|
||||||
|
summary: list[dict[str, Any]] = []
|
||||||
|
with eng.begin() as conn:
|
||||||
|
rows = conn.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
SELECT code, model, horizon,
|
||||||
|
AVG(CASE WHEN direction_hit THEN 1.0 ELSE 0.0 END) AS hit_rate,
|
||||||
|
AVG(abs_error) AS mae,
|
||||||
|
COUNT(*) AS n
|
||||||
|
FROM prediction_outcomes
|
||||||
|
WHERE resolved_at >= :start
|
||||||
|
GROUP BY code, model, horizon
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{"start": start},
|
||||||
|
).all()
|
||||||
|
for code, model, horizon, hr, mae, n in rows:
|
||||||
|
conn.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
INSERT INTO model_performance
|
||||||
|
(code, model, window_days, as_of, hit_rate, mae, sample_count)
|
||||||
|
VALUES (:c, :m, :w, :as_of, :hr, :mae, :n)
|
||||||
|
ON CONFLICT (code, model, window_days, as_of) DO UPDATE SET
|
||||||
|
hit_rate = EXCLUDED.hit_rate,
|
||||||
|
mae = EXCLUDED.mae,
|
||||||
|
sample_count = EXCLUDED.sample_count
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{
|
||||||
|
"c": code,
|
||||||
|
"m": model,
|
||||||
|
"w": WINDOW_DAYS,
|
||||||
|
"as_of": as_of,
|
||||||
|
"hr": float(hr) if hr is not None else None,
|
||||||
|
"mae": float(mae) if mae is not None else None,
|
||||||
|
"n": int(n),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
summary.append(
|
||||||
|
{
|
||||||
|
"code": code,
|
||||||
|
"model": model,
|
||||||
|
"horizon": horizon,
|
||||||
|
"hit_rate": float(hr) if hr is not None else None,
|
||||||
|
"mae": float(mae) if mae is not None else None,
|
||||||
|
"n": int(n),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return summary
|
||||||
|
|
||||||
|
|
||||||
|
def adjust_weights(as_of: date) -> list[dict[str, Any]]:
|
||||||
|
"""shadow chronos/lgbm hit_rate 로 ensemble_weights 자동 보정.
|
||||||
|
|
||||||
|
반환: (code, horizon, w_chronos, w_lgbm, hr_c, hr_l, n_c, n_l, action) 의
|
||||||
|
리스트. action ∈ {'updated', 'skipped_insufficient', 'skipped_zero'}.
|
||||||
|
"""
|
||||||
|
eng = get_engine()
|
||||||
|
start = as_of - timedelta(days=WINDOW_DAYS)
|
||||||
|
out: list[dict[str, Any]] = []
|
||||||
|
with eng.begin() as conn:
|
||||||
|
rows = conn.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
SELECT code, horizon, model,
|
||||||
|
AVG(CASE WHEN direction_hit THEN 1.0 ELSE 0.0 END) AS hit_rate,
|
||||||
|
COUNT(*) AS n
|
||||||
|
FROM prediction_outcomes
|
||||||
|
WHERE resolved_at >= :start
|
||||||
|
AND model IN ('chronos', 'lgbm')
|
||||||
|
GROUP BY code, horizon, model
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{"start": start},
|
||||||
|
).all()
|
||||||
|
|
||||||
|
# (code, horizon) -> {'chronos': (hr, n), 'lgbm': (hr, n)}
|
||||||
|
agg: dict[tuple[str, int], dict[str, tuple[float, int]]] = {}
|
||||||
|
for code, horizon, model, hr, n in rows:
|
||||||
|
key = (code, int(horizon))
|
||||||
|
agg.setdefault(key, {})[str(model)] = (
|
||||||
|
float(hr) if hr is not None else 0.0,
|
||||||
|
int(n),
|
||||||
|
)
|
||||||
|
|
||||||
|
for (code, horizon), m in agg.items():
|
||||||
|
c = m.get("chronos")
|
||||||
|
l = m.get("lgbm")
|
||||||
|
if c is None or l is None:
|
||||||
|
out.append({
|
||||||
|
"code": code, "horizon": horizon,
|
||||||
|
"action": "skipped_missing_model",
|
||||||
|
"have": list(m.keys()),
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
hr_c, n_c = c
|
||||||
|
hr_l, n_l = l
|
||||||
|
if n_c < MIN_SAMPLE or n_l < MIN_SAMPLE:
|
||||||
|
out.append({
|
||||||
|
"code": code, "horizon": horizon,
|
||||||
|
"hr_chronos": hr_c, "hr_lgbm": hr_l,
|
||||||
|
"n_chronos": n_c, "n_lgbm": n_l,
|
||||||
|
"action": "skipped_insufficient",
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
total = hr_c + hr_l
|
||||||
|
if total <= 0:
|
||||||
|
w_c = 0.5
|
||||||
|
else:
|
||||||
|
w_c = hr_c / total
|
||||||
|
w_c = max(W_CHRONOS_MIN, min(W_CHRONOS_MAX, w_c))
|
||||||
|
w_l = 1.0 - w_c
|
||||||
|
upsert_weights(
|
||||||
|
code, horizon, w_c, w_l,
|
||||||
|
hit_rate_chronos=hr_c, hit_rate_lgbm=hr_l,
|
||||||
|
sample_count=min(n_c, n_l),
|
||||||
|
)
|
||||||
|
out.append({
|
||||||
|
"code": code, "horizon": horizon,
|
||||||
|
"hr_chronos": hr_c, "hr_lgbm": hr_l,
|
||||||
|
"n_chronos": n_c, "n_lgbm": n_l,
|
||||||
|
"w_chronos": w_c, "w_lgbm": w_l,
|
||||||
|
"action": "updated",
|
||||||
|
})
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def run_weekly() -> dict[str, Any]:
|
||||||
|
"""일요일 02:00 KST 호출 entry-point."""
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
kst = timezone(timedelta(hours=9))
|
||||||
|
as_of = datetime.now(kst).date()
|
||||||
|
return {
|
||||||
|
"as_of": str(as_of),
|
||||||
|
"trained": retrain_all(),
|
||||||
|
"performance": record_performance(as_of),
|
||||||
|
"weights": adjust_weights(as_of),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
||||||
|
out = run_weekly()
|
||||||
|
print(json.dumps(out, ensure_ascii=False, indent=2, default=str))
|
||||||
@@ -22,6 +22,8 @@ from apscheduler.triggers.cron import CronTrigger
|
|||||||
from pytz import timezone
|
from pytz import timezone
|
||||||
|
|
||||||
from app.pipelines.daily_batch import run_daily_batch
|
from app.pipelines.daily_batch import run_daily_batch
|
||||||
|
from app.pipelines.match_outcomes import match_today
|
||||||
|
from app.pipelines.retrain_weekly import run_weekly
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
KST = timezone("Asia/Seoul")
|
KST = timezone("Asia/Seoul")
|
||||||
@@ -34,15 +36,34 @@ def start_scheduler() -> BackgroundScheduler:
|
|||||||
if _scheduler:
|
if _scheduler:
|
||||||
return _scheduler
|
return _scheduler
|
||||||
_scheduler = BackgroundScheduler(timezone=KST)
|
_scheduler = BackgroundScheduler(timezone=KST)
|
||||||
|
# 16:00 평일: 시드 10종목 EOD/뉴스/공시/거시 갱신
|
||||||
_scheduler.add_job(
|
_scheduler.add_job(
|
||||||
run_daily_batch,
|
run_daily_batch,
|
||||||
CronTrigger(hour=16, minute=0, timezone=KST),
|
CronTrigger(day_of_week="mon-fri", hour=16, minute=0, timezone=KST),
|
||||||
id="daily_batch_16",
|
id="daily_batch_16",
|
||||||
replace_existing=True,
|
replace_existing=True,
|
||||||
max_instances=1,
|
max_instances=1,
|
||||||
)
|
)
|
||||||
|
# 16:30 평일: prediction_outcomes 매칭 배치
|
||||||
|
_scheduler.add_job(
|
||||||
|
match_today,
|
||||||
|
CronTrigger(day_of_week="mon-fri", hour=16, minute=30, timezone=KST),
|
||||||
|
id="match_outcomes_1630",
|
||||||
|
replace_existing=True,
|
||||||
|
max_instances=1,
|
||||||
|
)
|
||||||
|
# 일요일 02:00: LGBM 재학습 + 성능 기록
|
||||||
|
_scheduler.add_job(
|
||||||
|
run_weekly,
|
||||||
|
CronTrigger(day_of_week="sun", hour=2, minute=0, timezone=KST),
|
||||||
|
id="retrain_weekly_sun_0200",
|
||||||
|
replace_existing=True,
|
||||||
|
max_instances=1,
|
||||||
|
)
|
||||||
_scheduler.start()
|
_scheduler.start()
|
||||||
logger.info("scheduler started (daily_batch @ 16:00 KST)")
|
logger.info(
|
||||||
|
"scheduler started: daily_batch(16:00 mon-fri), match_outcomes(16:30 mon-fri), retrain_weekly(sun 02:00) KST"
|
||||||
|
)
|
||||||
return _scheduler
|
return _scheduler
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -31,9 +31,12 @@ dependencies = [
|
|||||||
"transformers==4.41.2",
|
"transformers==4.41.2",
|
||||||
"tokenizers==0.19.1",
|
"tokenizers==0.19.1",
|
||||||
"sentencepiece==0.2.0",
|
"sentencepiece==0.2.0",
|
||||||
|
"accelerate==0.34.2",
|
||||||
|
"chronos-forecasting==1.4.1",
|
||||||
"scikit-learn==1.5.0",
|
"scikit-learn==1.5.0",
|
||||||
"lightgbm==4.3.0",
|
"lightgbm==4.3.0",
|
||||||
"ta==0.11.0",
|
"ta==0.11.0",
|
||||||
|
"joblib==1.4.2",
|
||||||
|
|
||||||
# scheduler
|
# scheduler
|
||||||
"apscheduler==3.10.4",
|
"apscheduler==3.10.4",
|
||||||
@@ -63,3 +66,19 @@ include = ["app*"]
|
|||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
line-length = 100
|
line-length = 100
|
||||||
target-version = "py311"
|
target-version = "py311"
|
||||||
|
|
||||||
|
[tool.mypy]
|
||||||
|
python_version = "3.11"
|
||||||
|
strict_optional = true
|
||||||
|
warn_unused_ignores = true
|
||||||
|
warn_redundant_casts = true
|
||||||
|
warn_unreachable = true
|
||||||
|
ignore_missing_imports = true # 3rd-party stub 부족 무시 (pykrx, chronos, ta, feedparser ...)
|
||||||
|
no_implicit_optional = true
|
||||||
|
show_error_codes = true
|
||||||
|
exclude = ["build", "dist", ".venv"]
|
||||||
|
|
||||||
|
[[tool.mypy.overrides]]
|
||||||
|
module = ["app.*"]
|
||||||
|
disallow_untyped_defs = false # 점진적 도입. 핵심 모듈만 strict 로 올릴 예정.
|
||||||
|
check_untyped_defs = true
|
||||||
|
|||||||
@@ -8,5 +8,8 @@ services:
|
|||||||
count: all
|
count: all
|
||||||
capabilities: [gpu]
|
capabilities: [gpu]
|
||||||
environment:
|
environment:
|
||||||
MODEL_DEVICE: cuda
|
# MODEL_DEVICE 는 .env 로 덮어쓰기 가능. GPU 빌드라도 PyTorch/CUDA 호환 문제 (예:
|
||||||
|
# 'no kernel image is available for execution on the device') 발생 시 .env 에
|
||||||
|
# MODEL_DEVICE=cpu 를 두고 `docker compose ... up -d backend` 로 회피.
|
||||||
|
MODEL_DEVICE: ${MODEL_DEVICE:-cuda}
|
||||||
NVIDIA_VISIBLE_DEVICES: all
|
NVIDIA_VISIBLE_DEVICES: all
|
||||||
|
|||||||
57
restart-ci.bat
Normal file
57
restart-ci.bat
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
@echo off
|
||||||
|
REM stock_chart_site - SSH/CI 친화 재시작 스크립트
|
||||||
|
REM
|
||||||
|
REM restart.bat 과의 차이: pause 가 없음. SSH 비대화형 (예: ssh user@host "restart-ci.bat")
|
||||||
|
REM 에서 멈추지 않고 끝까지 실행. 에러는 종료 코드로만 알린다.
|
||||||
|
REM
|
||||||
|
REM 일반 사용 시엔 restart.bat 을 쓰는게 출력 검토에 편하다.
|
||||||
|
|
||||||
|
setlocal enabledelayedexpansion
|
||||||
|
cd /d "%~dp0"
|
||||||
|
|
||||||
|
echo === stock_chart_site restart-ci ===
|
||||||
|
|
||||||
|
where docker >nul 2>&1
|
||||||
|
if errorlevel 1 (
|
||||||
|
echo [ERROR] docker not found
|
||||||
|
exit /b 1
|
||||||
|
)
|
||||||
|
docker info >nul 2>&1
|
||||||
|
if errorlevel 1 (
|
||||||
|
echo [ERROR] Docker Desktop not running
|
||||||
|
exit /b 1
|
||||||
|
)
|
||||||
|
|
||||||
|
set USE_GPU=0
|
||||||
|
where nvidia-smi >nul 2>&1
|
||||||
|
if not errorlevel 1 (
|
||||||
|
nvidia-smi >nul 2>&1
|
||||||
|
if not errorlevel 1 set USE_GPU=1
|
||||||
|
)
|
||||||
|
|
||||||
|
if "%USE_GPU%"=="1" (
|
||||||
|
echo [GPU] using GPU profile
|
||||||
|
set COMPOSE_FILES=-f docker-compose.yml -f docker-compose.gpu.yml
|
||||||
|
) else (
|
||||||
|
echo [CPU] using CPU profile
|
||||||
|
set COMPOSE_FILES=-f docker-compose.yml
|
||||||
|
)
|
||||||
|
|
||||||
|
for /f %%i in ('docker compose %COMPOSE_FILES% ps --status running --quiet backend web 2^>nul ^| find /v /c ""') do set RUN_COUNT=%%i
|
||||||
|
if "%RUN_COUNT%"=="0" (
|
||||||
|
echo [ERROR] backend/web not running. run build.bat first.
|
||||||
|
exit /b 1
|
||||||
|
)
|
||||||
|
|
||||||
|
echo === docker compose up -d --force-recreate --no-deps backend web ===
|
||||||
|
docker compose %COMPOSE_FILES% up -d --force-recreate --no-deps backend web
|
||||||
|
if errorlevel 1 (
|
||||||
|
echo [ERROR] restart failed
|
||||||
|
exit /b 1
|
||||||
|
)
|
||||||
|
|
||||||
|
echo === status ===
|
||||||
|
docker compose %COMPOSE_FILES% ps
|
||||||
|
|
||||||
|
endlocal
|
||||||
|
exit /b 0
|
||||||
94
restart.bat
Normal file
94
restart.bat
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
@echo off
|
||||||
|
REM stock_chart_site - Windows 재시작 스크립트
|
||||||
|
REM
|
||||||
|
REM build.bat 와의 차이:
|
||||||
|
REM - build.bat: 이미지 재빌드 포함. Dockerfile / pyproject.toml / package*.json /
|
||||||
|
REM compose 설정 등 의존성/이미지 구성이 바뀌었을 때 사용.
|
||||||
|
REM - restart.bat: 재빌드 없이 컨테이너만 재시작. backend/app/ 또는 web/app/ 안의
|
||||||
|
REM 코드만 바뀐 경우. docker-compose.yml 의 바인드 마운트 (./backend:/app,
|
||||||
|
REM ./web:/app) 덕에 새 코드가 즉시 컨테이너 안에서 보이고, 재시작으로
|
||||||
|
REM lifespan (부팅 시드 등) 도 다시 돌릴 수 있다.
|
||||||
|
REM
|
||||||
|
REM 즉 일반적으로 git pull 후:
|
||||||
|
REM - pyproject.toml / Dockerfile / package*.json 변경 있음 → build.bat
|
||||||
|
REM - app/ 코드만 변경 → restart.bat
|
||||||
|
|
||||||
|
setlocal enabledelayedexpansion
|
||||||
|
cd /d "%~dp0"
|
||||||
|
|
||||||
|
echo === stock_chart_site restart ===
|
||||||
|
|
||||||
|
REM 1) Docker 확인
|
||||||
|
where docker >nul 2>&1
|
||||||
|
if errorlevel 1 (
|
||||||
|
echo [ERROR] docker 명령을 찾을 수 없습니다. Docker Desktop 설치/실행을 확인하세요.
|
||||||
|
pause
|
||||||
|
exit /b 1
|
||||||
|
)
|
||||||
|
docker info >nul 2>&1
|
||||||
|
if errorlevel 1 (
|
||||||
|
echo [ERROR] Docker Desktop이 실행 중이 아닙니다.
|
||||||
|
pause
|
||||||
|
exit /b 1
|
||||||
|
)
|
||||||
|
|
||||||
|
REM 2) GPU 감지 (build.bat 과 동일 — compose 파일 조합 일치 위해)
|
||||||
|
set USE_GPU=0
|
||||||
|
where nvidia-smi >nul 2>&1
|
||||||
|
if not errorlevel 1 (
|
||||||
|
nvidia-smi >nul 2>&1
|
||||||
|
if not errorlevel 1 set USE_GPU=1
|
||||||
|
)
|
||||||
|
|
||||||
|
if "%USE_GPU%"=="1" (
|
||||||
|
echo [GPU] NVIDIA GPU detected. Using GPU profile.
|
||||||
|
set COMPOSE_FILES=-f docker-compose.yml -f docker-compose.gpu.yml
|
||||||
|
) else (
|
||||||
|
echo [CPU] NVIDIA GPU not detected. Using CPU profile.
|
||||||
|
set COMPOSE_FILES=-f docker-compose.yml
|
||||||
|
)
|
||||||
|
|
||||||
|
REM 3) backend/web 컨테이너 살아있는지 확인 — 없으면 build.bat 안내
|
||||||
|
REM (db 까지 포함해서 세면 db 만 떠있어도 통과돼버려서 부정확)
|
||||||
|
for /f %%i in ('docker compose %COMPOSE_FILES% ps --status running --quiet backend web 2^>nul ^| find /v /c ""') do set RUN_COUNT=%%i
|
||||||
|
if "%RUN_COUNT%"=="0" (
|
||||||
|
echo [INFO] 실행 중인 backend/web 컨테이너가 없습니다. 처음이거나 down 된 상태입니다.
|
||||||
|
echo build.bat 으로 빌드 + 기동하세요.
|
||||||
|
pause
|
||||||
|
exit /b 1
|
||||||
|
)
|
||||||
|
|
||||||
|
REM 4) backend + web 만 재기동 — db 는 건드리지 않음 (--no-deps).
|
||||||
|
REM
|
||||||
|
REM 왜 `restart` 가 아니라 `up -d --force-recreate` 인가:
|
||||||
|
REM - `docker compose restart` 는 기존 컨테이너를 stop/start 만 한다. 그래서
|
||||||
|
REM `.env` 변경 (예: KIS_APP_KEY 갱신) 이 반영되지 않는다. env_file 은
|
||||||
|
REM 컨테이너 "생성" 시점에만 읽힌다.
|
||||||
|
REM - `up -d --force-recreate` 는 새 컨테이너 인스턴스를 만들어서 env_file 을
|
||||||
|
REM 다시 읽는다. 이게 사용자가 .env 만 고치고 restart.bat 돌렸을 때 직관에 맞는다.
|
||||||
|
REM - `--no-deps` 로 db 는 절대 건드리지 않음. db 는 postgres_data 볼륨이 영속이라
|
||||||
|
REM 재기동할 이유 없고, depends_on.condition: service_healthy 와 무관하게 안전.
|
||||||
|
echo.
|
||||||
|
echo === docker compose up -d --force-recreate --no-deps backend web ===
|
||||||
|
docker compose %COMPOSE_FILES% up -d --force-recreate --no-deps backend web
|
||||||
|
if errorlevel 1 (
|
||||||
|
echo [ERROR] restart 실패.
|
||||||
|
pause
|
||||||
|
exit /b 1
|
||||||
|
)
|
||||||
|
|
||||||
|
echo.
|
||||||
|
echo === 상태 ===
|
||||||
|
docker compose %COMPOSE_FILES% ps
|
||||||
|
|
||||||
|
echo.
|
||||||
|
echo 접속:
|
||||||
|
echo Web http://localhost:3000
|
||||||
|
echo Backend http://localhost:8000/health
|
||||||
|
echo DB ext http://localhost:8000/health/db
|
||||||
|
echo.
|
||||||
|
echo 로그 보기: docker compose logs -f backend
|
||||||
|
echo 정지: docker compose down
|
||||||
|
echo.
|
||||||
|
pause
|
||||||
|
endlocal
|
||||||
6
web/.eslintrc.json
Normal file
6
web/.eslintrc.json
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
{
|
||||||
|
"extends": "next/core-web-vitals",
|
||||||
|
"rules": {
|
||||||
|
"react-hooks/exhaustive-deps": "warn"
|
||||||
|
}
|
||||||
|
}
|
||||||
158
web/app/[code]/page.tsx
Normal file
158
web/app/[code]/page.tsx
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import Link from "next/link";
|
||||||
|
import { useEffect, useState } from "react";
|
||||||
|
import { MetricsPanel } from "../../components/MetricsPanel";
|
||||||
|
import { NewsList } from "../../components/NewsList";
|
||||||
|
import { PredictionPanel } from "../../components/PredictionPanel";
|
||||||
|
import { StockChart } from "../../components/StockChart";
|
||||||
|
import {
|
||||||
|
api,
|
||||||
|
type ChartInterval,
|
||||||
|
type ChartPayload,
|
||||||
|
type LatestPredictionResponse,
|
||||||
|
} from "../../lib/api";
|
||||||
|
|
||||||
|
const INTERVALS: { label: string; value: ChartInterval; defaultDays: number }[] = [
|
||||||
|
{ label: "10분", value: "10m", defaultDays: 1 },
|
||||||
|
{ label: "일", value: "1d", defaultDays: 180 },
|
||||||
|
{ label: "주", value: "1w", defaultDays: 365 * 2 },
|
||||||
|
{ label: "월", value: "1mo", defaultDays: 365 * 5 },
|
||||||
|
];
|
||||||
|
|
||||||
|
export default function CodePage({ params }: { params: { code: string } }) {
|
||||||
|
const { code } = params;
|
||||||
|
const [chart, setChart] = useState<ChartPayload | null>(null);
|
||||||
|
const [prediction, setPrediction] = useState<LatestPredictionResponse | null>(null);
|
||||||
|
const [err, setErr] = useState<string | null>(null);
|
||||||
|
const [interval, setIntervalKind] = useState<ChartInterval>("1d");
|
||||||
|
const [days, setDays] = useState(180);
|
||||||
|
|
||||||
|
// interval 바꾸면 days 도 그 interval 에 맞는 기본값으로 (사용자가 명시적으로 다시 고를 수 있게).
|
||||||
|
function pickInterval(v: ChartInterval) {
|
||||||
|
const meta = INTERVALS.find((i) => i.value === v)!;
|
||||||
|
setIntervalKind(v);
|
||||||
|
setDays(meta.defaultDays);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 초기/주기적 차트 로드. 10분봉이면 60초마다 폴링 — 백엔드가 캐시-then-fetch 로
|
||||||
|
// 10분 이내면 DB 만 읽고, 넘었으면 KIS 호출. 폴링 부담은 낮음.
|
||||||
|
useEffect(() => {
|
||||||
|
let alive = true;
|
||||||
|
setErr(null);
|
||||||
|
setChart(null);
|
||||||
|
|
||||||
|
const load = () => {
|
||||||
|
api
|
||||||
|
.getChart(code, days, interval)
|
||||||
|
.then((c) => {
|
||||||
|
if (alive) setChart(c);
|
||||||
|
})
|
||||||
|
.catch((e) => {
|
||||||
|
if (alive) setErr(e instanceof Error ? e.message : String(e));
|
||||||
|
});
|
||||||
|
};
|
||||||
|
load();
|
||||||
|
|
||||||
|
if (interval === "10m") {
|
||||||
|
const h = window.setInterval(load, 60_000);
|
||||||
|
return () => {
|
||||||
|
alive = false;
|
||||||
|
window.clearInterval(h);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
return () => {
|
||||||
|
alive = false;
|
||||||
|
};
|
||||||
|
}, [code, days, interval]);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
let alive = true;
|
||||||
|
api
|
||||||
|
.latestPrediction(code)
|
||||||
|
.then((r) => {
|
||||||
|
if (alive && r.found) setPrediction(r);
|
||||||
|
})
|
||||||
|
.catch(() => {
|
||||||
|
// 예측 이력 없는 경우는 무시.
|
||||||
|
});
|
||||||
|
return () => {
|
||||||
|
alive = false;
|
||||||
|
};
|
||||||
|
}, [code]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<main className="mx-auto max-w-5xl px-6 py-10">
|
||||||
|
<div className="mb-4 flex items-center justify-between">
|
||||||
|
<Link href="/" className="text-xs text-zinc-500 hover:text-zinc-300">
|
||||||
|
← 검색으로
|
||||||
|
</Link>
|
||||||
|
<div className="flex items-center gap-2">
|
||||||
|
<div className="flex overflow-hidden rounded-md border border-zinc-700 text-xs">
|
||||||
|
{INTERVALS.map((it) => (
|
||||||
|
<button
|
||||||
|
key={it.value}
|
||||||
|
onClick={() => pickInterval(it.value)}
|
||||||
|
className={
|
||||||
|
interval === it.value
|
||||||
|
? "bg-emerald-700 px-3 py-1 text-white"
|
||||||
|
: "bg-zinc-900 px-3 py-1 text-zinc-300 hover:bg-zinc-800"
|
||||||
|
}
|
||||||
|
>
|
||||||
|
{it.label}
|
||||||
|
</button>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
{interval !== "10m" && (
|
||||||
|
<select
|
||||||
|
value={days}
|
||||||
|
onChange={(e) => setDays(Number(e.target.value))}
|
||||||
|
className="rounded-md border border-zinc-700 bg-zinc-900 px-2 py-1 text-xs"
|
||||||
|
>
|
||||||
|
<option value={60}>최근 3개월</option>
|
||||||
|
<option value={180}>최근 6개월</option>
|
||||||
|
<option value={365}>최근 1년</option>
|
||||||
|
<option value={365 * 2}>최근 2년</option>
|
||||||
|
<option value={365 * 5}>최근 5년</option>
|
||||||
|
<option value={365 * 10}>최근 10년</option>
|
||||||
|
</select>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{chart && (
|
||||||
|
<div className="mb-4 flex items-baseline justify-between">
|
||||||
|
<h1 className="text-2xl font-semibold text-zinc-100">
|
||||||
|
{chart.name}{" "}
|
||||||
|
<span className="text-sm font-normal text-zinc-500">
|
||||||
|
{chart.code} · {chart.market}
|
||||||
|
</span>
|
||||||
|
</h1>
|
||||||
|
{interval === "10m" && (
|
||||||
|
<div className="text-xs text-zinc-500">
|
||||||
|
실시간 10분봉 · 60초마다 갱신
|
||||||
|
{chart.intraday_status && (
|
||||||
|
<span className="ml-2 text-zinc-600">[{chart.intraday_status}]</span>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
|
||||||
|
{err && <div className="mb-4 text-sm text-red-400">차트 로딩 실패: {err}</div>}
|
||||||
|
|
||||||
|
{chart && (
|
||||||
|
<>
|
||||||
|
<StockChart chart={chart} prediction={prediction} />
|
||||||
|
<div className="mt-6">
|
||||||
|
<PredictionPanel code={code} initial={prediction} onResult={setPrediction} />
|
||||||
|
</div>
|
||||||
|
<div className="mt-6 grid gap-6 md:grid-cols-2">
|
||||||
|
<MetricsPanel code={code} />
|
||||||
|
<NewsList code={code} />
|
||||||
|
</div>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
|
</main>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -1,13 +1,32 @@
|
|||||||
|
import { SearchBox } from "../components/SearchBox";
|
||||||
|
|
||||||
export default function HomePage() {
|
export default function HomePage() {
|
||||||
return (
|
return (
|
||||||
<main className="mx-auto max-w-3xl px-6 py-16">
|
<main className="mx-auto max-w-3xl px-6 py-16">
|
||||||
<h1 className="text-3xl font-bold tracking-tight">Stock Chart Site</h1>
|
<h1 className="text-3xl font-bold tracking-tight">Stock Chart Site</h1>
|
||||||
<p className="mt-3 text-sm text-zinc-400">
|
<p className="mt-2 text-sm text-zinc-400">
|
||||||
Phase 0 scaffold. 종목 검색 UI는 Phase 6에서 추가됩니다.
|
종목을 검색해 현재 차트를 보고, <b>예상차트 보기</b> 버튼으로 Chronos + LightGBM
|
||||||
|
앙상블의 단기(1·3·5거래일) 예측을 차트에 이어 붙입니다.
|
||||||
</p>
|
</p>
|
||||||
<div className="mt-8 rounded-md border border-zinc-800 bg-zinc-900/50 p-4 text-sm">
|
|
||||||
<div className="font-medium">Backend health</div>
|
<div className="mt-8">
|
||||||
<code className="mt-2 block text-zinc-400">GET {process.env.NEXT_PUBLIC_API_BASE}/health</code>
|
<SearchBox autoFocus />
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="mt-12 grid gap-3 text-xs text-zinc-500 sm:grid-cols-2">
|
||||||
|
<div className="rounded-md border border-zinc-800 bg-zinc-900/30 p-4">
|
||||||
|
<div className="font-medium text-zinc-300">학습 대상 10종목</div>
|
||||||
|
<div className="mt-1">
|
||||||
|
삼성전자, SK하이닉스, 에코프로비엠, 한미반도체, 두산에너빌리티, 한화에어로스페이스,
|
||||||
|
HD현대중공업, NAVER, KT&G, 한국가스공사
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div className="rounded-md border border-zinc-800 bg-zinc-900/30 p-4">
|
||||||
|
<div className="font-medium text-zinc-300">매칭/재학습</div>
|
||||||
|
<div className="mt-1">
|
||||||
|
평일 16:30 KST 에 예측과 실제 종가 매칭, 일요일 02:00 KST 에 LGBM 재학습.
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</main>
|
</main>
|
||||||
);
|
);
|
||||||
|
|||||||
64
web/components/MetricsPanel.tsx
Normal file
64
web/components/MetricsPanel.tsx
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import { useEffect, useState } from "react";
|
||||||
|
import { api, type MetricsResponse } from "../lib/api";
|
||||||
|
|
||||||
|
export function MetricsPanel({ code }: { code: string }) {
|
||||||
|
const [m, setM] = useState<MetricsResponse | null>(null);
|
||||||
|
const [err, setErr] = useState<string | null>(null);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
let alive = true;
|
||||||
|
api
|
||||||
|
.metrics(code, 30)
|
||||||
|
.then((r) => {
|
||||||
|
if (alive) setM(r);
|
||||||
|
})
|
||||||
|
.catch((e) => {
|
||||||
|
if (alive) setErr(e instanceof Error ? e.message : String(e));
|
||||||
|
});
|
||||||
|
return () => {
|
||||||
|
alive = false;
|
||||||
|
};
|
||||||
|
}, [code]);
|
||||||
|
|
||||||
|
if (err) return <div className="text-xs text-red-400">메트릭 로딩 실패: {err}</div>;
|
||||||
|
if (!m) return <div className="text-xs text-zinc-500">메트릭 로딩 중…</div>;
|
||||||
|
|
||||||
|
const rows = m.by_model_horizon ?? [];
|
||||||
|
if (!rows.length) {
|
||||||
|
return (
|
||||||
|
<div className="text-xs text-zinc-500">
|
||||||
|
최근 30일 매칭된 예측 결과가 아직 없습니다. (매칭 배치는 평일 16:30 KST 에 실행)
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="rounded-md border border-zinc-800 bg-zinc-900/40 p-4">
|
||||||
|
<div className="mb-2 text-sm font-medium text-zinc-200">최근 30일 모델 성능</div>
|
||||||
|
<table className="w-full text-left text-sm">
|
||||||
|
<thead className="text-xs text-zinc-500">
|
||||||
|
<tr>
|
||||||
|
<th className="py-1">모델</th>
|
||||||
|
<th>+거래일</th>
|
||||||
|
<th>표본 수</th>
|
||||||
|
<th>방향 적중률</th>
|
||||||
|
<th>평균 절대오차 (MAE)</th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody className="divide-y divide-zinc-800">
|
||||||
|
{rows.map((r, i) => (
|
||||||
|
<tr key={i}>
|
||||||
|
<td className="py-1">{r.model}</td>
|
||||||
|
<td>+{r.horizon}</td>
|
||||||
|
<td>{r.n}</td>
|
||||||
|
<td>{r.hit_rate != null ? `${(r.hit_rate * 100).toFixed(1)}%` : "-"}</td>
|
||||||
|
<td>{r.mae != null ? r.mae.toLocaleString(undefined, { maximumFractionDigits: 1 }) : "-"}</td>
|
||||||
|
</tr>
|
||||||
|
))}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
77
web/components/NewsList.tsx
Normal file
77
web/components/NewsList.tsx
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import { useEffect, useState } from "react";
|
||||||
|
import { api, type NewsResponse } from "../lib/api";
|
||||||
|
|
||||||
|
export function NewsList({ code }: { code: string }) {
|
||||||
|
const [data, setData] = useState<NewsResponse | null>(null);
|
||||||
|
const [err, setErr] = useState<string | null>(null);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
let alive = true;
|
||||||
|
api
|
||||||
|
.news(code, 20)
|
||||||
|
.then((r) => {
|
||||||
|
if (alive) setData(r);
|
||||||
|
})
|
||||||
|
.catch((e) => {
|
||||||
|
if (alive) setErr(e instanceof Error ? e.message : String(e));
|
||||||
|
});
|
||||||
|
return () => {
|
||||||
|
alive = false;
|
||||||
|
};
|
||||||
|
}, [code]);
|
||||||
|
|
||||||
|
if (err) return <div className="text-xs text-red-400">뉴스 로딩 실패: {err}</div>;
|
||||||
|
if (!data) return <div className="text-xs text-zinc-500">뉴스 로딩 중…</div>;
|
||||||
|
if (!data.items.length)
|
||||||
|
return <div className="text-xs text-zinc-500">최근 뉴스 없음</div>;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="rounded-md border border-zinc-800 bg-zinc-900/40 p-4">
|
||||||
|
<div className="mb-2 text-sm font-medium text-zinc-200">최근 뉴스/공시</div>
|
||||||
|
<ul className="divide-y divide-zinc-800">
|
||||||
|
{data.items.map((n, i) => (
|
||||||
|
<li key={i} className="py-2">
|
||||||
|
<a
|
||||||
|
href={n.url}
|
||||||
|
target="_blank"
|
||||||
|
rel="noopener noreferrer"
|
||||||
|
className="block hover:bg-zinc-800/40"
|
||||||
|
>
|
||||||
|
<div className="text-sm text-zinc-100 line-clamp-2">{n.title}</div>
|
||||||
|
<div className="mt-0.5 flex items-center gap-2 text-xs text-zinc-500">
|
||||||
|
<span>{n.source}</span>
|
||||||
|
{n.published_at && <span>· {formatDate(n.published_at)}</span>}
|
||||||
|
{n.sentiment_label && (
|
||||||
|
<span className={sentimentColor(n.sentiment_label)}>
|
||||||
|
· {n.sentiment_label} {n.sentiment_score != null ? `(${n.sentiment_score.toFixed(2)})` : ""}
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</a>
|
||||||
|
</li>
|
||||||
|
))}
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
function sentimentColor(l: string): string {
|
||||||
|
if (l === "positive") return "text-emerald-400";
|
||||||
|
if (l === "negative") return "text-red-400";
|
||||||
|
return "text-zinc-400";
|
||||||
|
}
|
||||||
|
|
||||||
|
function formatDate(iso: string): string {
|
||||||
|
try {
|
||||||
|
const d = new Date(iso);
|
||||||
|
return `${d.getFullYear()}-${pad(d.getMonth() + 1)}-${pad(d.getDate())} ${pad(d.getHours())}:${pad(d.getMinutes())}`;
|
||||||
|
} catch {
|
||||||
|
return iso;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function pad(n: number): string {
|
||||||
|
return n < 10 ? `0${n}` : `${n}`;
|
||||||
|
}
|
||||||
223
web/components/PredictionPanel.tsx
Normal file
223
web/components/PredictionPanel.tsx
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import { useState } from "react";
|
||||||
|
import {
|
||||||
|
api,
|
||||||
|
type LatestPredictionResponse,
|
||||||
|
type LatestPredictionStep,
|
||||||
|
type PredictResponse,
|
||||||
|
} from "../lib/api";
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
code: string;
|
||||||
|
initial?: LatestPredictionResponse | null;
|
||||||
|
onResult: (pred: LatestPredictionResponse) => void;
|
||||||
|
};
|
||||||
|
|
||||||
|
function normalizeFromPredictResponse(
|
||||||
|
code: string,
|
||||||
|
resp: PredictResponse,
|
||||||
|
): LatestPredictionResponse {
|
||||||
|
const steps: LatestPredictionStep[] = resp.steps.map((s) => ({
|
||||||
|
predicted_at: null,
|
||||||
|
target_date: s.target_date ?? "",
|
||||||
|
horizon: s.horizon,
|
||||||
|
direction: s.direction,
|
||||||
|
prob_up: s.prob_up,
|
||||||
|
prob_flat: s.prob_flat,
|
||||||
|
prob_down: s.prob_down,
|
||||||
|
expected_return: s.expected_return,
|
||||||
|
point_close: s.point_close,
|
||||||
|
ci_low: s.ci_low,
|
||||||
|
ci_high: s.ci_high,
|
||||||
|
user_triggered: resp.user_triggered,
|
||||||
|
features_snapshot: null,
|
||||||
|
}));
|
||||||
|
return {
|
||||||
|
code,
|
||||||
|
found: true,
|
||||||
|
base_date: resp.base_date,
|
||||||
|
base_close: resp.base_close,
|
||||||
|
steps,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
const HORIZON_PRESETS: { label: string; value: number[] }[] = [
|
||||||
|
{ label: "단기 (1·3·5)", value: [1, 3, 5] },
|
||||||
|
{ label: "중기 (1·5·10)", value: [1, 5, 10] },
|
||||||
|
{ label: "장기 (5·10·20)", value: [5, 10, 20] },
|
||||||
|
];
|
||||||
|
|
||||||
|
export function PredictionPanel({ code, initial, onResult }: Props) {
|
||||||
|
const [pred, setPred] = useState<LatestPredictionResponse | null>(initial ?? null);
|
||||||
|
const [loading, setLoading] = useState(false);
|
||||||
|
const [err, setErr] = useState<string | null>(null);
|
||||||
|
const [presetIdx, setPresetIdx] = useState(0);
|
||||||
|
const [customRaw, setCustomRaw] = useState("");
|
||||||
|
const [useCustom, setUseCustom] = useState(false);
|
||||||
|
|
||||||
|
function effectiveHorizons(): number[] {
|
||||||
|
if (useCustom) {
|
||||||
|
// 백엔드 predict.py 가 1~30 만 허용 (모델 학습/검증 범위와 일치).
|
||||||
|
// 프론트에서 동일 cap 으로 끊어서 400 안 나게 한다.
|
||||||
|
const parsed = customRaw
|
||||||
|
.split(",")
|
||||||
|
.map((s) => Number(s.trim()))
|
||||||
|
.filter((n) => Number.isFinite(n) && n >= 1 && n <= 30);
|
||||||
|
if (parsed.length > 0) return Array.from(new Set(parsed)).sort((a, b) => a - b);
|
||||||
|
}
|
||||||
|
return HORIZON_PRESETS[presetIdx].value;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function runPredict() {
|
||||||
|
setLoading(true);
|
||||||
|
setErr(null);
|
||||||
|
try {
|
||||||
|
const horizons = effectiveHorizons();
|
||||||
|
const r = await api.predict(code, horizons);
|
||||||
|
const normalized = normalizeFromPredictResponse(code, r);
|
||||||
|
setPred(normalized);
|
||||||
|
onResult(normalized);
|
||||||
|
} catch (e) {
|
||||||
|
setErr(e instanceof Error ? e.message : String(e));
|
||||||
|
} finally {
|
||||||
|
setLoading(false);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const steps = pred?.steps ?? [];
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="rounded-md border border-zinc-800 bg-zinc-900/40 p-4">
|
||||||
|
<div className="mb-3 flex items-center justify-between">
|
||||||
|
<div>
|
||||||
|
<div className="text-sm font-medium text-zinc-200">예측 (Chronos + LightGBM 앙상블)</div>
|
||||||
|
<div className="text-xs text-zinc-500">
|
||||||
|
클릭한 종목은 자동 저장 후 다음 거래일 장 종료 시 실제 가격과 비교됩니다.
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<button
|
||||||
|
onClick={runPredict}
|
||||||
|
disabled={loading}
|
||||||
|
className="rounded-md bg-emerald-700 px-4 py-2 text-sm font-medium text-white hover:bg-emerald-600 disabled:opacity-50"
|
||||||
|
>
|
||||||
|
{loading ? "예측 중…" : pred?.found ? "다시 예측" : "예상차트 보기"}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="mb-3 flex flex-wrap items-center gap-2 text-xs">
|
||||||
|
<span className="text-zinc-500">예측 거래일:</span>
|
||||||
|
{HORIZON_PRESETS.map((p, i) => (
|
||||||
|
<button
|
||||||
|
key={p.label}
|
||||||
|
onClick={() => {
|
||||||
|
setUseCustom(false);
|
||||||
|
setPresetIdx(i);
|
||||||
|
}}
|
||||||
|
className={
|
||||||
|
!useCustom && presetIdx === i
|
||||||
|
? "rounded-full bg-emerald-700 px-3 py-1 text-white"
|
||||||
|
: "rounded-full border border-zinc-700 px-3 py-1 text-zinc-300 hover:border-zinc-500"
|
||||||
|
}
|
||||||
|
>
|
||||||
|
{p.label}
|
||||||
|
</button>
|
||||||
|
))}
|
||||||
|
<label
|
||||||
|
className={
|
||||||
|
useCustom
|
||||||
|
? "flex items-center gap-1 rounded-full bg-emerald-700 px-3 py-1 text-white"
|
||||||
|
: "flex items-center gap-1 rounded-full border border-zinc-700 px-3 py-1 text-zinc-300 hover:border-zinc-500"
|
||||||
|
}
|
||||||
|
>
|
||||||
|
<input
|
||||||
|
type="checkbox"
|
||||||
|
checked={useCustom}
|
||||||
|
onChange={(e) => setUseCustom(e.target.checked)}
|
||||||
|
className="h-3 w-3"
|
||||||
|
/>
|
||||||
|
직접
|
||||||
|
<input
|
||||||
|
type="text"
|
||||||
|
placeholder="1~30, 예: 1,2,3,7"
|
||||||
|
value={customRaw}
|
||||||
|
onChange={(e) => setCustomRaw(e.target.value)}
|
||||||
|
onFocus={() => setUseCustom(true)}
|
||||||
|
className="ml-1 w-28 rounded border border-zinc-700 bg-zinc-900 px-1 py-0.5 text-xs text-zinc-100"
|
||||||
|
/>
|
||||||
|
</label>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{err && <div className="mb-3 text-xs text-red-400">에러: {err}</div>}
|
||||||
|
|
||||||
|
{pred?.found ? (
|
||||||
|
<div>
|
||||||
|
<div className="mb-2 text-xs text-zinc-500">
|
||||||
|
기준일 {pred.base_date} · 기준종가{" "}
|
||||||
|
{pred.base_close != null ? pred.base_close.toLocaleString() : "-"}
|
||||||
|
</div>
|
||||||
|
<table className="w-full text-left text-sm">
|
||||||
|
<thead className="text-xs text-zinc-500">
|
||||||
|
<tr>
|
||||||
|
<th className="py-1">+거래일</th>
|
||||||
|
<th>매칭일</th>
|
||||||
|
<th>방향</th>
|
||||||
|
<th>P(up/flat/down)</th>
|
||||||
|
<th>기대수익</th>
|
||||||
|
<th>예측 종가</th>
|
||||||
|
<th>q10~q90</th>
|
||||||
|
</tr>
|
||||||
|
</thead>
|
||||||
|
<tbody className="divide-y divide-zinc-800">
|
||||||
|
{steps.map((s) => (
|
||||||
|
<tr key={s.horizon}>
|
||||||
|
<td className="py-2">+{s.horizon}</td>
|
||||||
|
<td className="text-xs text-zinc-400">{s.target_date}</td>
|
||||||
|
<td>
|
||||||
|
<span
|
||||||
|
className={
|
||||||
|
s.direction === "up"
|
||||||
|
? "text-emerald-400"
|
||||||
|
: s.direction === "down"
|
||||||
|
? "text-red-400"
|
||||||
|
: "text-zinc-300"
|
||||||
|
}
|
||||||
|
>
|
||||||
|
{s.direction}
|
||||||
|
</span>
|
||||||
|
</td>
|
||||||
|
<td className="text-xs text-zinc-300">
|
||||||
|
{fmtPct(s.prob_up)} / {fmtPct(s.prob_flat)} / {fmtPct(s.prob_down)}
|
||||||
|
</td>
|
||||||
|
<td>{fmtSignedPct(s.expected_return)}</td>
|
||||||
|
<td>{s.point_close != null ? s.point_close.toLocaleString() : "-"}</td>
|
||||||
|
<td className="text-xs text-zinc-400">
|
||||||
|
{s.ci_low != null ? s.ci_low.toLocaleString() : "-"} ~{" "}
|
||||||
|
{s.ci_high != null ? s.ci_high.toLocaleString() : "-"}
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
))}
|
||||||
|
</tbody>
|
||||||
|
</table>
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
<div className="text-xs text-zinc-500">
|
||||||
|
아직 예측이 없습니다. <b>예상차트 보기</b> 버튼을 누르면 1·3·5거래일 후 예측을 생성하고
|
||||||
|
차트에 점선으로 이어 붙입니다.
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
function fmtPct(v: number | null): string {
|
||||||
|
if (v == null) return "-";
|
||||||
|
return `${(v * 100).toFixed(0)}%`;
|
||||||
|
}
|
||||||
|
|
||||||
|
function fmtSignedPct(v: number | null): string {
|
||||||
|
if (v == null) return "-";
|
||||||
|
const pct = v * 100;
|
||||||
|
const sign = pct >= 0 ? "+" : "";
|
||||||
|
return `${sign}${pct.toFixed(2)}%`;
|
||||||
|
}
|
||||||
93
web/components/SearchBox.tsx
Normal file
93
web/components/SearchBox.tsx
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import Link from "next/link";
|
||||||
|
import { useEffect, useState } from "react";
|
||||||
|
import { api, type Symbol } from "../lib/api";
|
||||||
|
|
||||||
|
export function SearchBox({ autoFocus = false }: { autoFocus?: boolean }) {
|
||||||
|
const [q, setQ] = useState("");
|
||||||
|
const [items, setItems] = useState<Symbol[]>([]);
|
||||||
|
const [loading, setLoading] = useState(false);
|
||||||
|
const [err, setErr] = useState<string | null>(null);
|
||||||
|
const [seedOnly, setSeedOnly] = useState(false);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
const term = q.trim();
|
||||||
|
if (!term) {
|
||||||
|
setItems([]);
|
||||||
|
setErr(null);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
setLoading(true);
|
||||||
|
setErr(null);
|
||||||
|
const handle = setTimeout(async () => {
|
||||||
|
try {
|
||||||
|
const r = await api.search(term, seedOnly, 15);
|
||||||
|
setItems(r.items);
|
||||||
|
} catch (e) {
|
||||||
|
setErr(e instanceof Error ? e.message : String(e));
|
||||||
|
setItems([]);
|
||||||
|
} finally {
|
||||||
|
setLoading(false);
|
||||||
|
}
|
||||||
|
}, 200);
|
||||||
|
return () => clearTimeout(handle);
|
||||||
|
}, [q, seedOnly]);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="w-full">
|
||||||
|
<div className="flex items-center gap-3">
|
||||||
|
<input
|
||||||
|
autoFocus={autoFocus}
|
||||||
|
value={q}
|
||||||
|
onChange={(e) => setQ(e.target.value)}
|
||||||
|
placeholder="종목명 또는 코드 (예: 삼성, 005930)"
|
||||||
|
className="w-full rounded-md border border-zinc-700 bg-zinc-900 px-4 py-3 text-base outline-none focus:border-zinc-500"
|
||||||
|
/>
|
||||||
|
<label className="flex shrink-0 items-center gap-2 text-xs text-zinc-400">
|
||||||
|
<input
|
||||||
|
type="checkbox"
|
||||||
|
checked={seedOnly}
|
||||||
|
onChange={(e) => setSeedOnly(e.target.checked)}
|
||||||
|
/>
|
||||||
|
학습대상 10종목만
|
||||||
|
</label>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="mt-2 min-h-[1.5rem] text-xs text-zinc-500">
|
||||||
|
{loading && "검색중…"}
|
||||||
|
{err && <span className="text-red-400">에러: {err}</span>}
|
||||||
|
{!loading && !err && q && items.length === 0 && "검색 결과 없음"}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{items.length > 0 && (
|
||||||
|
<ul className="mt-2 divide-y divide-zinc-800 rounded-md border border-zinc-800 bg-zinc-900/40">
|
||||||
|
{items.map((it) => (
|
||||||
|
<li key={it.code}>
|
||||||
|
<Link
|
||||||
|
href={`/${it.code}`}
|
||||||
|
className="flex items-center justify-between px-4 py-3 text-sm hover:bg-zinc-800/70"
|
||||||
|
>
|
||||||
|
<div>
|
||||||
|
<div className="font-medium text-zinc-100">
|
||||||
|
{it.name}
|
||||||
|
{it.is_seed && (
|
||||||
|
<span className="ml-2 rounded-full bg-emerald-900/60 px-2 py-0.5 text-[10px] font-semibold text-emerald-300">
|
||||||
|
SEED
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
<div className="text-xs text-zinc-500">
|
||||||
|
{it.code} · {it.market}
|
||||||
|
{it.sector ? ` · ${it.sector}` : ""}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<span className="text-zinc-500">→</span>
|
||||||
|
</Link>
|
||||||
|
</li>
|
||||||
|
))}
|
||||||
|
</ul>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
210
web/components/StockChart.tsx
Normal file
210
web/components/StockChart.tsx
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import { useEffect, useRef } from "react";
|
||||||
|
import {
|
||||||
|
createChart,
|
||||||
|
type CandlestickData,
|
||||||
|
type IChartApi,
|
||||||
|
type ISeriesApi,
|
||||||
|
type LineData,
|
||||||
|
type UTCTimestamp,
|
||||||
|
} from "lightweight-charts";
|
||||||
|
import type { ChartPayload, LatestPredictionResponse } from "../lib/api";
|
||||||
|
|
||||||
|
type Props = {
|
||||||
|
chart: ChartPayload;
|
||||||
|
prediction?: LatestPredictionResponse | null;
|
||||||
|
};
|
||||||
|
|
||||||
|
// 'YYYY-MM-DD' 또는 'YYYY-MM-DDTHH:MM:SS' (KST naive, 백엔드가 +09:00 시각의 wall-clock 을
|
||||||
|
// 그대로 ISO 로 직렬화) 를 UTCTimestamp 로. lightweight-charts 는 timestamp 가 UTC 라고
|
||||||
|
// 가정하지만, 우리는 KST wall-clock 을 UTC 인 척 넣는다 — timeScale 의 표시도 KST 그대로
|
||||||
|
// 나와서 한국 사용자에겐 가장 직관적.
|
||||||
|
function isoToUtcTs(s: string): UTCTimestamp {
|
||||||
|
if (s.length <= 10) {
|
||||||
|
return (Date.UTC(
|
||||||
|
Number(s.slice(0, 4)),
|
||||||
|
Number(s.slice(5, 7)) - 1,
|
||||||
|
Number(s.slice(8, 10)),
|
||||||
|
) / 1000) as UTCTimestamp;
|
||||||
|
}
|
||||||
|
// datetime: YYYY-MM-DDTHH:MM:SS
|
||||||
|
return (Date.UTC(
|
||||||
|
Number(s.slice(0, 4)),
|
||||||
|
Number(s.slice(5, 7)) - 1,
|
||||||
|
Number(s.slice(8, 10)),
|
||||||
|
Number(s.slice(11, 13)),
|
||||||
|
Number(s.slice(14, 16)),
|
||||||
|
Number(s.slice(17, 19) || "0"),
|
||||||
|
) / 1000) as UTCTimestamp;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function StockChart({ chart, prediction }: Props) {
|
||||||
|
const containerRef = useRef<HTMLDivElement | null>(null);
|
||||||
|
const chartRef = useRef<IChartApi | null>(null);
|
||||||
|
const candleRef = useRef<ISeriesApi<"Candlestick"> | null>(null);
|
||||||
|
const predRef = useRef<ISeriesApi<"Line"> | null>(null);
|
||||||
|
const predLowRef = useRef<ISeriesApi<"Line"> | null>(null);
|
||||||
|
const predHighRef = useRef<ISeriesApi<"Line"> | null>(null);
|
||||||
|
|
||||||
|
const isIntraday = chart.interval === "10m";
|
||||||
|
|
||||||
|
// create chart once (interval 바뀌면 timeVisible 토글 위해 의존성에 isIntraday 포함 — 재생성)
|
||||||
|
useEffect(() => {
|
||||||
|
if (!containerRef.current) return;
|
||||||
|
const c = createChart(containerRef.current, {
|
||||||
|
layout: {
|
||||||
|
background: { color: "transparent" },
|
||||||
|
textColor: "#cbd5e1",
|
||||||
|
},
|
||||||
|
grid: {
|
||||||
|
vertLines: { color: "#1f2937" },
|
||||||
|
horzLines: { color: "#1f2937" },
|
||||||
|
},
|
||||||
|
rightPriceScale: { borderColor: "#374151" },
|
||||||
|
timeScale: {
|
||||||
|
borderColor: "#374151",
|
||||||
|
timeVisible: isIntraday,
|
||||||
|
secondsVisible: false,
|
||||||
|
},
|
||||||
|
autoSize: true,
|
||||||
|
});
|
||||||
|
const candle = c.addCandlestickSeries({
|
||||||
|
upColor: "#22c55e",
|
||||||
|
downColor: "#ef4444",
|
||||||
|
borderUpColor: "#22c55e",
|
||||||
|
borderDownColor: "#ef4444",
|
||||||
|
wickUpColor: "#22c55e",
|
||||||
|
wickDownColor: "#ef4444",
|
||||||
|
});
|
||||||
|
chartRef.current = c;
|
||||||
|
candleRef.current = candle;
|
||||||
|
return () => {
|
||||||
|
c.remove();
|
||||||
|
chartRef.current = null;
|
||||||
|
candleRef.current = null;
|
||||||
|
predRef.current = null;
|
||||||
|
predLowRef.current = null;
|
||||||
|
predHighRef.current = null;
|
||||||
|
};
|
||||||
|
}, [isIntraday]);
|
||||||
|
|
||||||
|
// push candle data + today marker
|
||||||
|
useEffect(() => {
|
||||||
|
if (!candleRef.current) return;
|
||||||
|
const data: CandlestickData[] = chart.ohlcv
|
||||||
|
.filter((p) => p.open !== null && p.high !== null && p.low !== null && p.close !== null)
|
||||||
|
.map((p) => ({
|
||||||
|
time: isoToUtcTs(p.date),
|
||||||
|
open: p.open as number,
|
||||||
|
high: p.high as number,
|
||||||
|
low: p.low as number,
|
||||||
|
close: p.close as number,
|
||||||
|
}));
|
||||||
|
candleRef.current.setData(data);
|
||||||
|
// 오늘 표시는 차트 본체 위가 아니라 컨테이너 아래 캡션 (return JSX) 으로 옮김.
|
||||||
|
// lightweight-charts 의 timeScale tick 자체에 라벨을 끼울 공식 API 가 없어서,
|
||||||
|
// 시각적으로 동일한 위치 (시간축 바로 아래) 에 별도 div 로 렌더.
|
||||||
|
chartRef.current?.timeScale().fitContent();
|
||||||
|
}, [chart, isIntraday]);
|
||||||
|
|
||||||
|
// push prediction overlay (10분봉에서는 표시 안 함 — 예측은 일봉 기준)
|
||||||
|
useEffect(() => {
|
||||||
|
if (!chartRef.current) return;
|
||||||
|
if (predRef.current) {
|
||||||
|
chartRef.current.removeSeries(predRef.current);
|
||||||
|
predRef.current = null;
|
||||||
|
}
|
||||||
|
if (predLowRef.current) {
|
||||||
|
chartRef.current.removeSeries(predLowRef.current);
|
||||||
|
predLowRef.current = null;
|
||||||
|
}
|
||||||
|
if (predHighRef.current) {
|
||||||
|
chartRef.current.removeSeries(predHighRef.current);
|
||||||
|
predHighRef.current = null;
|
||||||
|
}
|
||||||
|
if (isIntraday) return;
|
||||||
|
if (!prediction || !prediction.found || !prediction.steps?.length) return;
|
||||||
|
const baseDate = prediction.base_date!;
|
||||||
|
const baseClose = prediction.base_close;
|
||||||
|
if (!baseClose) return;
|
||||||
|
const sorted = [...prediction.steps].sort((a, b) => a.horizon - b.horizon);
|
||||||
|
|
||||||
|
const med: LineData[] = [
|
||||||
|
{ time: isoToUtcTs(baseDate), value: baseClose },
|
||||||
|
...sorted
|
||||||
|
.filter((s) => s.point_close !== null)
|
||||||
|
.map((s) => ({ time: isoToUtcTs(s.target_date), value: s.point_close as number })),
|
||||||
|
];
|
||||||
|
const lo: LineData[] = [
|
||||||
|
{ time: isoToUtcTs(baseDate), value: baseClose },
|
||||||
|
...sorted
|
||||||
|
.filter((s) => s.ci_low !== null)
|
||||||
|
.map((s) => ({ time: isoToUtcTs(s.target_date), value: s.ci_low as number })),
|
||||||
|
];
|
||||||
|
const hi: LineData[] = [
|
||||||
|
{ time: isoToUtcTs(baseDate), value: baseClose },
|
||||||
|
...sorted
|
||||||
|
.filter((s) => s.ci_high !== null)
|
||||||
|
.map((s) => ({ time: isoToUtcTs(s.target_date), value: s.ci_high as number })),
|
||||||
|
];
|
||||||
|
|
||||||
|
const medLine = chartRef.current.addLineSeries({
|
||||||
|
color: "#a78bfa",
|
||||||
|
lineWidth: 2,
|
||||||
|
lineStyle: 2, // dashed
|
||||||
|
priceLineVisible: false,
|
||||||
|
lastValueVisible: true,
|
||||||
|
title: "예측 median",
|
||||||
|
});
|
||||||
|
medLine.setData(med);
|
||||||
|
const loLine = chartRef.current.addLineSeries({
|
||||||
|
color: "#7c3aed",
|
||||||
|
lineWidth: 1,
|
||||||
|
lineStyle: 1,
|
||||||
|
priceLineVisible: false,
|
||||||
|
lastValueVisible: false,
|
||||||
|
title: "q10",
|
||||||
|
});
|
||||||
|
loLine.setData(lo);
|
||||||
|
const hiLine = chartRef.current.addLineSeries({
|
||||||
|
color: "#7c3aed",
|
||||||
|
lineWidth: 1,
|
||||||
|
lineStyle: 1,
|
||||||
|
priceLineVisible: false,
|
||||||
|
lastValueVisible: false,
|
||||||
|
title: "q90",
|
||||||
|
});
|
||||||
|
hiLine.setData(hi);
|
||||||
|
predRef.current = medLine;
|
||||||
|
predLowRef.current = loLine;
|
||||||
|
predHighRef.current = hiLine;
|
||||||
|
chartRef.current.timeScale().fitContent();
|
||||||
|
}, [prediction, isIntraday]);
|
||||||
|
|
||||||
|
// 오늘 라벨 — 차트 본체에 마커 대신 시간축 바로 아래에 작은 캡션으로.
|
||||||
|
// 10분봉은 데이터 자체가 오늘 하루라 굳이 라벨 불필요.
|
||||||
|
const todayLabel =
|
||||||
|
!isIntraday && chart.today
|
||||||
|
? new Date(chart.today + "T00:00:00").toLocaleDateString("ko-KR", {
|
||||||
|
year: "numeric",
|
||||||
|
month: "2-digit",
|
||||||
|
day: "2-digit",
|
||||||
|
weekday: "short",
|
||||||
|
})
|
||||||
|
: null;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="w-full rounded-md border border-zinc-800 bg-zinc-900/30 p-2">
|
||||||
|
<div className="h-[460px] w-full">
|
||||||
|
<div ref={containerRef} className="h-full w-full" />
|
||||||
|
</div>
|
||||||
|
{todayLabel && (
|
||||||
|
<div className="mt-1 flex items-center justify-end gap-2 px-2 text-xs text-zinc-400">
|
||||||
|
<span className="inline-block h-2 w-2 rounded-full bg-amber-400" />
|
||||||
|
<span>오늘 · {todayLabel}</span>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
212
web/lib/api.ts
Normal file
212
web/lib/api.ts
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
// Backend API client.
|
||||||
|
//
|
||||||
|
// API 베이스 해석 우선순위:
|
||||||
|
// 1) NEXT_PUBLIC_API_BASE 가 localhost/127.0.0.1 이 아닌 명시값 → 그대로 사용
|
||||||
|
// (예: 프로덕션 https://api.example.com)
|
||||||
|
// 2) 브라우저 환경 → window.location.hostname:8000 (LAN 접속도 자동 대응)
|
||||||
|
// 3) SSR 폴백 → http://localhost:8000
|
||||||
|
//
|
||||||
|
// docker-compose 가 NEXT_PUBLIC_API_BASE=http://localhost:8000 을 주입하는 경우가 흔한데,
|
||||||
|
// LAN 의 다른 PC 에서 http://<host>:3000 으로 접속하면 inline 된 localhost 가 그쪽 PC 의
|
||||||
|
// localhost 를 가리켜 깨진다. 그래서 localhost/127.0.0.1 값은 신뢰하지 않고 페이지 host 로
|
||||||
|
// 폴백.
|
||||||
|
|
||||||
|
function resolveApiBase(): string {
|
||||||
|
const raw = process.env.NEXT_PUBLIC_API_BASE;
|
||||||
|
const env = raw && raw.length > 0 ? raw.replace(/\/$/, "") : "";
|
||||||
|
const envIsLocal = !env || /\/\/(localhost|127\.0\.0\.1)(?::|$)/.test(env);
|
||||||
|
if (typeof window !== "undefined") {
|
||||||
|
if (envIsLocal) {
|
||||||
|
return `${window.location.protocol}//${window.location.hostname}:8000`;
|
||||||
|
}
|
||||||
|
return env;
|
||||||
|
}
|
||||||
|
// SSR
|
||||||
|
return env || "http://localhost:8000";
|
||||||
|
}
|
||||||
|
|
||||||
|
export const API_BASE = resolveApiBase();
|
||||||
|
|
||||||
|
export type Symbol = {
|
||||||
|
code: string;
|
||||||
|
name: string;
|
||||||
|
market: string;
|
||||||
|
sector: string | null;
|
||||||
|
is_seed: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type SymbolSearch = {
|
||||||
|
q: string;
|
||||||
|
count: number;
|
||||||
|
items: Symbol[];
|
||||||
|
};
|
||||||
|
|
||||||
|
export type OhlcvPoint = {
|
||||||
|
// 1d/1w/1mo: 'YYYY-MM-DD' / 10m: 'YYYY-MM-DDTHH:MM:SS' (KST naive ISO)
|
||||||
|
date: string;
|
||||||
|
open: number | null;
|
||||||
|
high: number | null;
|
||||||
|
low: number | null;
|
||||||
|
close: number | null;
|
||||||
|
volume: number | null;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ChartInterval = "10m" | "1d" | "1w" | "1mo";
|
||||||
|
|
||||||
|
export type SentimentPoint = {
|
||||||
|
date: string;
|
||||||
|
n_articles: number;
|
||||||
|
mean_score: number | null;
|
||||||
|
weighted_score: number | null;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type TradingValuePoint = {
|
||||||
|
date: string;
|
||||||
|
foreign_net: number | null;
|
||||||
|
institution_net: number | null;
|
||||||
|
individual_net: number | null;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ChartPayload = {
|
||||||
|
code: string;
|
||||||
|
name: string;
|
||||||
|
market: string;
|
||||||
|
interval: ChartInterval;
|
||||||
|
intraday_status: string | null;
|
||||||
|
range: { from: string; to: string };
|
||||||
|
today: string;
|
||||||
|
ohlcv: OhlcvPoint[];
|
||||||
|
sentiment: SentimentPoint[];
|
||||||
|
trading_value: TradingValuePoint[];
|
||||||
|
};
|
||||||
|
|
||||||
|
export type PredictionStep = {
|
||||||
|
horizon: number;
|
||||||
|
target_idx?: number;
|
||||||
|
point_close: number;
|
||||||
|
ci_low: number;
|
||||||
|
ci_high: number;
|
||||||
|
prob_up: number;
|
||||||
|
prob_flat: number;
|
||||||
|
prob_down: number;
|
||||||
|
direction: "up" | "flat" | "down";
|
||||||
|
expected_return: number;
|
||||||
|
target_date?: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type PredictResponse = {
|
||||||
|
code: string;
|
||||||
|
base_date: string;
|
||||||
|
base_close: number;
|
||||||
|
sources_used: string[];
|
||||||
|
steps: PredictionStep[];
|
||||||
|
saved_prediction_ids: number[];
|
||||||
|
saved_shadow_ids?: { chronos: number[]; lgbm: number[] };
|
||||||
|
user_triggered: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type LatestPredictionStep = {
|
||||||
|
predicted_at: string | null;
|
||||||
|
target_date: string;
|
||||||
|
horizon: number;
|
||||||
|
direction: "up" | "flat" | "down" | string;
|
||||||
|
prob_up: number | null;
|
||||||
|
prob_flat: number | null;
|
||||||
|
prob_down: number | null;
|
||||||
|
expected_return: number | null;
|
||||||
|
point_close: number | null;
|
||||||
|
ci_low: number | null;
|
||||||
|
ci_high: number | null;
|
||||||
|
user_triggered: boolean;
|
||||||
|
features_snapshot: unknown;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type LatestPredictionResponse = {
|
||||||
|
code: string;
|
||||||
|
name?: string;
|
||||||
|
found: boolean;
|
||||||
|
base_date?: string;
|
||||||
|
base_close?: number | null;
|
||||||
|
steps: LatestPredictionStep[];
|
||||||
|
};
|
||||||
|
|
||||||
|
export type MetricsRow = {
|
||||||
|
model: string;
|
||||||
|
horizon: number;
|
||||||
|
n: number;
|
||||||
|
hit_rate: number | null;
|
||||||
|
mae: number | null;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type MetricsResponse = {
|
||||||
|
code?: string;
|
||||||
|
name?: string;
|
||||||
|
window_days: number;
|
||||||
|
range: { from: string; to: string };
|
||||||
|
by_model_horizon: MetricsRow[];
|
||||||
|
};
|
||||||
|
|
||||||
|
export type NewsItem = {
|
||||||
|
source: string;
|
||||||
|
published_at: string | null;
|
||||||
|
title: string;
|
||||||
|
url: string;
|
||||||
|
sentiment_score: number | null;
|
||||||
|
sentiment_label: string | null;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type NewsResponse = {
|
||||||
|
code: string;
|
||||||
|
name: string;
|
||||||
|
count: number;
|
||||||
|
items: NewsItem[];
|
||||||
|
};
|
||||||
|
|
||||||
|
async function getJson<T>(path: string, init?: RequestInit): Promise<T> {
|
||||||
|
const res = await fetch(`${API_BASE}${path}`, {
|
||||||
|
...init,
|
||||||
|
headers: {
|
||||||
|
Accept: "application/json",
|
||||||
|
...(init?.headers ?? {}),
|
||||||
|
},
|
||||||
|
cache: "no-store",
|
||||||
|
});
|
||||||
|
if (!res.ok) {
|
||||||
|
const text = await res.text().catch(() => "");
|
||||||
|
throw new Error(`API ${path} → ${res.status} ${text || res.statusText}`);
|
||||||
|
}
|
||||||
|
return (await res.json()) as T;
|
||||||
|
}
|
||||||
|
|
||||||
|
export const api = {
|
||||||
|
search: (q: string, seedOnly = false, limit = 20) =>
|
||||||
|
getJson<SymbolSearch>(
|
||||||
|
`/api/symbols/search?q=${encodeURIComponent(q)}&limit=${limit}&seed_only=${seedOnly}`,
|
||||||
|
),
|
||||||
|
getSymbol: (code: string) => getJson<Symbol>(`/api/symbols/${encodeURIComponent(code)}`),
|
||||||
|
getChart: (code: string, days = 180, interval: ChartInterval = "1d") =>
|
||||||
|
getJson<ChartPayload>(
|
||||||
|
`/api/chart/${encodeURIComponent(code)}?days=${days}&interval=${encodeURIComponent(interval)}`,
|
||||||
|
),
|
||||||
|
predict: (code: string, horizons: string | number[] = "1,3,5") => {
|
||||||
|
const h = Array.isArray(horizons) ? horizons.join(",") : horizons;
|
||||||
|
return getJson<PredictResponse>(
|
||||||
|
`/api/predict/${encodeURIComponent(code)}?horizons=${encodeURIComponent(h)}`,
|
||||||
|
{ method: "POST" },
|
||||||
|
);
|
||||||
|
},
|
||||||
|
latestPrediction: (code: string) =>
|
||||||
|
getJson<LatestPredictionResponse>(`/api/predict/${encodeURIComponent(code)}/latest`),
|
||||||
|
metrics: (code: string, windowDays = 30) =>
|
||||||
|
getJson<MetricsResponse>(
|
||||||
|
`/api/metrics/${encodeURIComponent(code)}?window_days=${windowDays}`,
|
||||||
|
),
|
||||||
|
overallMetrics: (windowDays = 30) =>
|
||||||
|
getJson<MetricsResponse>(`/api/metrics?window_days=${windowDays}`),
|
||||||
|
news: (code: string, limit = 20, source?: string) =>
|
||||||
|
getJson<NewsResponse>(
|
||||||
|
`/api/news/${encodeURIComponent(code)}?limit=${limit}${
|
||||||
|
source ? `&source=${encodeURIComponent(source)}` : ""
|
||||||
|
}`,
|
||||||
|
),
|
||||||
|
};
|
||||||
2
web/next-env.d.ts
vendored
2
web/next-env.d.ts
vendored
@@ -2,4 +2,4 @@
|
|||||||
/// <reference types="next/image-types/global" />
|
/// <reference types="next/image-types/global" />
|
||||||
|
|
||||||
// NOTE: This file should not be edited
|
// NOTE: This file should not be edited
|
||||||
// see https://nextjs.org/docs/basic-features/typescript for more information.
|
// see https://nextjs.org/docs/app/building-your-application/configuring/typescript for more information.
|
||||||
|
|||||||
5853
web/package-lock.json
generated
Normal file
5853
web/package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load Diff
@@ -6,10 +6,12 @@
|
|||||||
"dev": "next dev -p 3000 -H 0.0.0.0",
|
"dev": "next dev -p 3000 -H 0.0.0.0",
|
||||||
"build": "next build",
|
"build": "next build",
|
||||||
"start": "next start -p 3000 -H 0.0.0.0",
|
"start": "next start -p 3000 -H 0.0.0.0",
|
||||||
"lint": "next lint"
|
"lint": "next lint",
|
||||||
|
"typecheck": "tsc --noEmit",
|
||||||
|
"check": "npm run typecheck && npm run lint"
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"next": "14.2.3",
|
"next": "^14.2.33",
|
||||||
"react": "18.3.1",
|
"react": "18.3.1",
|
||||||
"react-dom": "18.3.1",
|
"react-dom": "18.3.1",
|
||||||
"lightweight-charts": "4.1.7"
|
"lightweight-charts": "4.1.7"
|
||||||
@@ -21,6 +23,8 @@
|
|||||||
"typescript": "5.4.5",
|
"typescript": "5.4.5",
|
||||||
"tailwindcss": "3.4.4",
|
"tailwindcss": "3.4.4",
|
||||||
"postcss": "8.4.38",
|
"postcss": "8.4.38",
|
||||||
"autoprefixer": "10.4.19"
|
"autoprefixer": "10.4.19",
|
||||||
|
"eslint": "8.57.0",
|
||||||
|
"eslint-config-next": "^14.2.33"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,6 +13,9 @@
|
|||||||
"isolatedModules": true,
|
"isolatedModules": true,
|
||||||
"jsx": "preserve",
|
"jsx": "preserve",
|
||||||
"incremental": true,
|
"incremental": true,
|
||||||
|
"forceConsistentCasingInFileNames": true,
|
||||||
|
"noFallthroughCasesInSwitch": true,
|
||||||
|
"noImplicitOverride": true,
|
||||||
"plugins": [{ "name": "next" }],
|
"plugins": [{ "name": "next" }],
|
||||||
"paths": { "@/*": ["./*"] }
|
"paths": { "@/*": ["./*"] }
|
||||||
},
|
},
|
||||||
|
|||||||
Reference in New Issue
Block a user