Compare commits
3 Commits
bc016ab76d
...
5e6ce11491
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5e6ce11491 | ||
|
|
0af556396e | ||
|
|
f84b460e54 |
10
README.md
10
README.md
@@ -155,9 +155,13 @@ stock_chart_site/
|
|||||||
|
|
||||||
## 동작 모델 메모
|
## 동작 모델 메모
|
||||||
|
|
||||||
- 예측 트리거: 사용자가 "예상차트 보기" 누른 종목에 대해 즉시 inference. 결과는 `predictions(user_triggered=TRUE)` 로 저장.
|
- 예측 트리거: 사용자가 "예상차트 보기" 누른 종목에 대해 즉시 inference. 결과는 세 종류 행으로 적재:
|
||||||
- 매칭 배치: 다음 거래일 장 종료 후 (KRX 정규장 마감 15:30, 종가 확정 후 16:00 ~ 16:30 KST 사이) `user_triggered=TRUE` 인 예측 중 `target_date == 오늘 거래일`인 행들에 대해 실제 종가/방향과 매칭 → `prediction_outcomes` 적재. 주말/공휴일이면 다음 거래일로 이월.
|
- `model='ensemble'` (user_triggered=TRUE) — UI 가 표시하는 최종 예측
|
||||||
- 주간 02:00 (일요일): 종목/모델별 최근 30일 hit rate 기반으로 앙상블 가중치를 자동 보정. hit rate가 임계 미만이면 LGBM 재학습.
|
- `model='chronos'` (user_triggered=FALSE, shadow) — Chronos 단독 성능 추적용
|
||||||
|
- `model='lgbm'` (user_triggered=FALSE, shadow) — LGBM 단독 성능 추적용
|
||||||
|
- 매칭 배치: 평일 16:30 KST. `target_date <= today AND outcomes 미존재` 인 모든 행에 대해 `target_date` 이상 `today` 이하 범위의 **최초 거래일 종가**를 actual_close 로 사용 → 주말/공휴일 자동 이월. shadow 행도 함께 매칭됨.
|
||||||
|
- 주간 02:00 (일요일): 시드 10종목 × horizons LGBM 재학습. 최근 30일 prediction_outcomes 의 chronos vs lgbm hit_rate 비교 → `w_chronos = clamp(0.1, hr_c/(hr_c+hr_l), 0.9)` 공식으로 `ensemble_weights` upsert. 모델별 표본이 10 미만이면 기본값(0.6/0.4) 유지.
|
||||||
|
- DB bootstrap: 백엔드 첫 부팅 시 lifespan 에서 idempotent migration + symbols 시드(비어있을 때만 pykrx 전 종목 적재) 자동 수행. `BOOTSTRAP_DISABLED=1` 로 비활성화 가능.
|
||||||
|
|
||||||
## 안전/한계
|
## 안전/한계
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
from app.api.chart import router as chart_router
|
from app.api.chart import router as chart_router
|
||||||
from app.api.metrics import router as metrics_router
|
from app.api.metrics import router as metrics_router
|
||||||
@@ -13,7 +15,7 @@ from app.api.predict import router as predict_router
|
|||||||
from app.api.refresh import router as refresh_router
|
from app.api.refresh import router as refresh_router
|
||||||
from app.api.symbols import router as symbols_router
|
from app.api.symbols import router as symbols_router
|
||||||
from app.config import settings
|
from app.config import settings
|
||||||
from app.db.connection import ping as db_ping
|
from app.db.connection import get_engine, ping as db_ping
|
||||||
from app.fetch import dart as dart_mod
|
from app.fetch import dart as dart_mod
|
||||||
from app.fetch import kis as kis_mod
|
from app.fetch import kis as kis_mod
|
||||||
from app.pipelines.scheduler import shutdown_scheduler, start_scheduler
|
from app.pipelines.scheduler import shutdown_scheduler, start_scheduler
|
||||||
@@ -25,13 +27,55 @@ logging.basicConfig(
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _bootstrap_db() -> None:
|
||||||
|
"""첫 부팅 자동화:
|
||||||
|
1) migrations/*.sql idempotent 적용 (timescale/pgvector 확장 + 스키마)
|
||||||
|
2) symbols 테이블 비어있으면 pykrx 로 전 종목 시드 (SEED 10 마크 포함)
|
||||||
|
|
||||||
|
BOOTSTRAP_DISABLED=1 이면 스킵 (테스트/CI 용). 어떤 단계든 실패해도 서버는
|
||||||
|
뜬다 — /health/db 가 진단을 알려준다.
|
||||||
|
"""
|
||||||
|
if os.environ.get("BOOTSTRAP_DISABLED") == "1":
|
||||||
|
logger.info("bootstrap skipped (BOOTSTRAP_DISABLED=1)")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 1) migrations
|
||||||
|
try:
|
||||||
|
from app.db.migrate import apply_all
|
||||||
|
res = apply_all()
|
||||||
|
logger.info("bootstrap migrate: %s", res)
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
logger.exception("bootstrap migrate failed")
|
||||||
|
return # 스키마 없으면 시드 불가
|
||||||
|
|
||||||
|
# 2) symbols 시드 (비어있을 때만 — pykrx 호출이 비싸므로 항상 돌리지 않음)
|
||||||
|
try:
|
||||||
|
eng = get_engine()
|
||||||
|
with eng.connect() as conn:
|
||||||
|
row = conn.execute(text("SELECT COUNT(*) FROM symbols")).first()
|
||||||
|
count = int(row[0]) if row else 0
|
||||||
|
if count == 0:
|
||||||
|
logger.info("symbols empty — running initial seed")
|
||||||
|
from app.fetch.symbols_seed import seed_symbols
|
||||||
|
report = seed_symbols()
|
||||||
|
logger.info("bootstrap seed_symbols: %s", report)
|
||||||
|
else:
|
||||||
|
logger.info("symbols already populated (count=%d) — skip seed", count)
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
logger.exception("bootstrap seed_symbols failed")
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(_: FastAPI):
|
async def lifespan(_: FastAPI):
|
||||||
|
_bootstrap_db()
|
||||||
# 스케줄러는 옵션. CI/테스트에서 disable 하고 싶으면 SCHEDULER_DISABLED 같은 env 추가 가능.
|
# 스케줄러는 옵션. CI/테스트에서 disable 하고 싶으면 SCHEDULER_DISABLED 같은 env 추가 가능.
|
||||||
try:
|
if os.environ.get("SCHEDULER_DISABLED") == "1":
|
||||||
start_scheduler()
|
logger.info("scheduler skipped (SCHEDULER_DISABLED=1)")
|
||||||
except Exception: # noqa: BLE001
|
else:
|
||||||
logger.exception("scheduler start failed")
|
try:
|
||||||
|
start_scheduler()
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
logger.exception("scheduler start failed")
|
||||||
yield
|
yield
|
||||||
shutdown_scheduler()
|
shutdown_scheduler()
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ Chronos 도 실패하면 LGBM 단독으로 진행.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -53,6 +53,10 @@ class EnsemblePrediction:
|
|||||||
horizons: list[int]
|
horizons: list[int]
|
||||||
steps: list[EnsembleStep]
|
steps: list[EnsembleStep]
|
||||||
sources_used: list[str]
|
sources_used: list[str]
|
||||||
|
# shadow 저장용 원본 출력 (predict_one.py 가 ensemble + chronos 단독 + lgbm 단독
|
||||||
|
# 3 종을 predictions 에 적재해서 retrain_weekly 가 모델별 hit_rate 비교 가능하게 함).
|
||||||
|
chronos_raw: ChronosForecast | None = None
|
||||||
|
lgbm_raw: dict[int, LgbmForecast] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
def _chronos_direction(samples: list[list[float]], base_close: float, horizon: int) -> tuple[float, float, float]:
|
def _chronos_direction(samples: list[list[float]], base_close: float, horizon: int) -> tuple[float, float, float]:
|
||||||
@@ -90,12 +94,14 @@ def predict(code: str, *, horizons: tuple[int, ...] = (1, 3, 5)) -> EnsemblePred
|
|||||||
logger.warning("chronos forecast failed for %s: %s", code, exc)
|
logger.warning("chronos forecast failed for %s: %s", code, exc)
|
||||||
|
|
||||||
steps: list[EnsembleStep] = []
|
steps: list[EnsembleStep] = []
|
||||||
|
lgbm_raw: dict[int, LgbmForecast] = {}
|
||||||
for h in horizons:
|
for h in horizons:
|
||||||
lf: LgbmForecast | None = None
|
lf: LgbmForecast | None = None
|
||||||
try:
|
try:
|
||||||
lf = lgbm_predict(code, h)
|
lf = lgbm_predict(code, h)
|
||||||
if lf is not None:
|
if lf is not None:
|
||||||
sources_used.append(f"lgbm_h{h}")
|
sources_used.append(f"lgbm_h{h}")
|
||||||
|
lgbm_raw[h] = lf
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
logger.warning("lgbm predict failed for %s h=%d: %s", code, h, exc)
|
logger.warning("lgbm predict failed for %s h=%d: %s", code, h, exc)
|
||||||
|
|
||||||
@@ -171,4 +177,6 @@ def predict(code: str, *, horizons: tuple[int, ...] = (1, 3, 5)) -> EnsemblePred
|
|||||||
horizons=list(horizons),
|
horizons=list(horizons),
|
||||||
steps=steps,
|
steps=steps,
|
||||||
sources_used=sources_used,
|
sources_used=sources_used,
|
||||||
|
chronos_raw=cf,
|
||||||
|
lgbm_raw=lgbm_raw,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,13 +1,16 @@
|
|||||||
"""prediction_outcomes 매칭 배치.
|
"""prediction_outcomes 매칭 배치.
|
||||||
|
|
||||||
평일 16:30 KST 에 실행. 다음 거래일 장 종료 후 (KRX 정규장 마감 15:30) 의
|
평일 16:30 KST 에 실행. 다음 거래일 장 종료 후 (KRX 정규장 마감 15:30) 의
|
||||||
확정 종가가 16:00~16:30 사이 pykrx 로 들어온 뒤, target_date == 오늘인
|
확정 종가가 16:00~16:30 사이 pykrx 로 들어온 뒤, 매칭 미해결 예측을 실제
|
||||||
user_triggered=TRUE 예측을 그 종가와 매칭.
|
종가와 매칭한다.
|
||||||
|
|
||||||
cold-start / 휴장일 대비: 인자로 받은 target_date 의 ohlcv_daily 에 종가가
|
이월/공휴일 정책:
|
||||||
없으면 자연스럽게 skip. 다음 거래일 매칭 잡이 다시 시도하면 그 날짜는
|
target_date 가 calendar date 라서 비거래일이면 ohlcv_daily 에 행이 없다.
|
||||||
여전히 매칭되지 않으므로 (매칭 sql 이 target_date 기준), 영원히 매칭 안되는
|
그래서 `target_date <= today` 인 미해결 행을 전부 후보로 잡고, 각 행마다
|
||||||
잘못된 calendar date 예측은 cleanup CLI 로 별도 정리 가능 (Phase 7).
|
`target_date <= ohlcv_daily.date <= today` 범위의 최초 거래일 종가로
|
||||||
|
매칭한다 (=다음 거래일로 자동 이월).
|
||||||
|
|
||||||
|
shadow prediction 도 같은 방식으로 매칭한다 (user_triggered 필터 없음).
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@@ -29,7 +32,7 @@ FLAT_BAND = 0.003
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MatchSummary:
|
class MatchSummary:
|
||||||
target_date: str
|
today: str
|
||||||
candidates: int
|
candidates: int
|
||||||
matched: int
|
matched: int
|
||||||
skipped_no_actual: int
|
skipped_no_actual: int
|
||||||
@@ -44,64 +47,62 @@ def _direction_label(ret: float) -> str:
|
|||||||
return "flat"
|
return "flat"
|
||||||
|
|
||||||
|
|
||||||
def match_for_date(d: date) -> MatchSummary:
|
def match_up_to(today: date) -> MatchSummary:
|
||||||
"""target_date == d 인 user_triggered=TRUE 예측을 매칭."""
|
"""target_date <= today 인 모든 미해결 예측을 매칭.
|
||||||
|
|
||||||
|
각 행마다 ohlcv_daily 에서 target_date 이상, today 이하 범위의 최초
|
||||||
|
거래일 종가를 actual_close 로 사용 — 공휴일/주말 이월 자연 처리.
|
||||||
|
"""
|
||||||
eng = get_engine()
|
eng = get_engine()
|
||||||
with eng.begin() as conn:
|
with eng.begin() as conn:
|
||||||
# 매칭 대상 예측 + 매칭 안 됐는지 확인.
|
|
||||||
candidate_rows = conn.execute(
|
candidate_rows = conn.execute(
|
||||||
text(
|
text(
|
||||||
"""
|
"""
|
||||||
SELECT p.id, p.code, p.base_date, p.horizon, p.point_forecast,
|
SELECT p.id, p.code, p.base_date, p.target_date, p.horizon,
|
||||||
p.direction, p.model
|
p.point_forecast, p.direction, p.model
|
||||||
FROM predictions p
|
FROM predictions p
|
||||||
LEFT JOIN prediction_outcomes po ON po.prediction_id = p.id
|
LEFT JOIN prediction_outcomes po ON po.prediction_id = p.id
|
||||||
WHERE p.target_date = :d
|
WHERE p.target_date <= :today
|
||||||
AND p.user_triggered = TRUE
|
|
||||||
AND po.prediction_id IS NULL
|
AND po.prediction_id IS NULL
|
||||||
"""
|
"""
|
||||||
),
|
),
|
||||||
{"d": d},
|
{"today": today},
|
||||||
).all()
|
).all()
|
||||||
candidates = len(candidate_rows)
|
candidates = len(candidate_rows)
|
||||||
if not candidates:
|
if not candidates:
|
||||||
return MatchSummary(str(d), 0, 0, 0, 0)
|
return MatchSummary(str(today), 0, 0, 0, 0)
|
||||||
|
|
||||||
# 종목별로 actual close 조회 (한번에 batch).
|
|
||||||
codes = list({r[1] for r in candidate_rows})
|
|
||||||
actual_map: dict[tuple[str, date], float] = {}
|
|
||||||
for code in codes:
|
|
||||||
row = conn.execute(
|
|
||||||
text(
|
|
||||||
"SELECT close FROM ohlcv_daily WHERE code = :c AND date = :d"
|
|
||||||
),
|
|
||||||
{"c": code, "d": d},
|
|
||||||
).first()
|
|
||||||
if row and row[0] is not None:
|
|
||||||
actual_map[(code, d)] = float(row[0])
|
|
||||||
|
|
||||||
# base_close (각 예측의 base_date 종가) 도 필요 — direction 판정용.
|
|
||||||
base_close_map: dict[tuple[str, date], float] = {}
|
|
||||||
for pid, code, base_date, *_ in candidate_rows:
|
|
||||||
key = (code, base_date)
|
|
||||||
if key in base_close_map:
|
|
||||||
continue
|
|
||||||
row = conn.execute(
|
|
||||||
text("SELECT close FROM ohlcv_daily WHERE code = :c AND date = :d"),
|
|
||||||
{"c": code, "d": base_date},
|
|
||||||
).first()
|
|
||||||
if row and row[0] is not None:
|
|
||||||
base_close_map[key] = float(row[0])
|
|
||||||
|
|
||||||
matched = 0
|
matched = 0
|
||||||
skipped = 0
|
skipped = 0
|
||||||
already = 0
|
already = 0
|
||||||
for pid, code, base_date, horizon, point_forecast, pred_dir, model in candidate_rows:
|
for pid, code, base_date, target_date, horizon, point_forecast, pred_dir, model in candidate_rows:
|
||||||
actual = actual_map.get((code, d))
|
# 첫 거래일 종가 (target_date <= date <= today)
|
||||||
base_close = base_close_map.get((code, base_date))
|
actual_row = conn.execute(
|
||||||
if actual is None or base_close is None:
|
text(
|
||||||
|
"""
|
||||||
|
SELECT date, close FROM ohlcv_daily
|
||||||
|
WHERE code = :c AND date >= :td AND date <= :today
|
||||||
|
ORDER BY date ASC
|
||||||
|
LIMIT 1
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{"c": code, "td": target_date, "today": today},
|
||||||
|
).first()
|
||||||
|
if not actual_row or actual_row[1] is None:
|
||||||
skipped += 1
|
skipped += 1
|
||||||
continue
|
continue
|
||||||
|
actual_date = actual_row[0]
|
||||||
|
actual = float(actual_row[1])
|
||||||
|
|
||||||
|
base_close_row = conn.execute(
|
||||||
|
text("SELECT close FROM ohlcv_daily WHERE code = :c AND date = :d"),
|
||||||
|
{"c": code, "d": base_date},
|
||||||
|
).first()
|
||||||
|
if not base_close_row or base_close_row[0] is None:
|
||||||
|
skipped += 1
|
||||||
|
continue
|
||||||
|
base_close = float(base_close_row[0])
|
||||||
|
|
||||||
actual_ret = actual / base_close - 1.0
|
actual_ret = actual / base_close - 1.0
|
||||||
actual_dir = _direction_label(actual_ret)
|
actual_dir = _direction_label(actual_ret)
|
||||||
dir_hit = (pred_dir == actual_dir)
|
dir_hit = (pred_dir == actual_dir)
|
||||||
@@ -122,7 +123,8 @@ def match_for_date(d: date) -> MatchSummary:
|
|||||||
{
|
{
|
||||||
"pid": pid,
|
"pid": pid,
|
||||||
"code": code,
|
"code": code,
|
||||||
"d": d,
|
# 실제 매칭된 거래일 (이월된 경우 target_date 와 다를 수 있음)
|
||||||
|
"d": actual_date,
|
||||||
"h": horizon,
|
"h": horizon,
|
||||||
"m": model,
|
"m": model,
|
||||||
"pc": float(point_forecast) if point_forecast is not None else None,
|
"pc": float(point_forecast) if point_forecast is not None else None,
|
||||||
@@ -138,7 +140,7 @@ def match_for_date(d: date) -> MatchSummary:
|
|||||||
already += 1
|
already += 1
|
||||||
|
|
||||||
return MatchSummary(
|
return MatchSummary(
|
||||||
target_date=str(d),
|
today=str(today),
|
||||||
candidates=candidates,
|
candidates=candidates,
|
||||||
matched=matched,
|
matched=matched,
|
||||||
skipped_no_actual=skipped,
|
skipped_no_actual=skipped,
|
||||||
@@ -146,13 +148,19 @@ def match_for_date(d: date) -> MatchSummary:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# 하위 호환 alias — 이전 시그니처를 쓰던 호출자 (예: 단일 날짜 매칭 테스트)
|
||||||
|
def match_for_date(d: date) -> MatchSummary:
|
||||||
|
"""legacy: target_date == d 만 매칭하던 동작 → 이제 target_date <= d 전체 처리."""
|
||||||
|
return match_up_to(d)
|
||||||
|
|
||||||
|
|
||||||
def match_today() -> dict[str, Any]:
|
def match_today() -> dict[str, Any]:
|
||||||
"""평일 16:30 KST 호출용. target_date == today (KST) 인 행 매칭."""
|
"""평일 16:30 KST 호출용. target_date <= today (KST) 인 미해결 행 일괄 매칭."""
|
||||||
from datetime import datetime, timezone, timedelta as td
|
from datetime import datetime, timezone, timedelta as td
|
||||||
|
|
||||||
kst = timezone(td(hours=9))
|
kst = timezone(td(hours=9))
|
||||||
today = datetime.now(kst).date()
|
today = datetime.now(kst).date()
|
||||||
summary = match_for_date(today)
|
summary = match_up_to(today)
|
||||||
return {
|
return {
|
||||||
"today": str(today),
|
"today": str(today),
|
||||||
"summary": summary.__dict__,
|
"summary": summary.__dict__,
|
||||||
|
|||||||
@@ -3,12 +3,16 @@
|
|||||||
POST /api/predict/{code} 에서 호출. 사용자가 "예상차트 보기" 누른 시점.
|
POST /api/predict/{code} 에서 호출. 사용자가 "예상차트 보기" 누른 시점.
|
||||||
- ensemble.predict() 로 horizons (1,3,5) 결과 계산
|
- ensemble.predict() 로 horizons (1,3,5) 결과 계산
|
||||||
- base_date = 마지막 ohlcv_daily.date, target_date = base_date + horizon 영업일
|
- base_date = 마지막 ohlcv_daily.date, target_date = base_date + horizon 영업일
|
||||||
(대충 calendar 일로 +h * 1.4 — KRX 영업일 추정. Phase 4 단순화: base_date + h 영업일은
|
(주말만 스킵하는 단순 카운트. 공휴일은 match_outcomes 가 "target_date 이후
|
||||||
ohlcv 상의 다음 h 거래일이 아닌, "거래일 카운트" 대신 단순 calendar+h 로 저장하고
|
최초 거래일 종가"로 자동 이월하여 보정.)
|
||||||
매칭 잡에서 ohlcv_daily 에 그 날짜 행이 있는지로 자연 보정.)
|
|
||||||
|
|
||||||
대안 정확도 위해: 매칭 잡은 "예측의 target_date 이 오늘"인 행을 그날 종가와 비교.
|
세 종류의 행을 함께 저장한다:
|
||||||
calendar date 가 비거래일이면 매칭이 안 되니, 매칭 잡은 매일 실행되어 모일 때 처리.
|
- model='ensemble' : 사용자에게 보여주는 최종 예측. user_triggered 플래그 따라감.
|
||||||
|
- model='chronos' : Chronos 단독 (shadow). user_triggered=FALSE 로 항상 적재.
|
||||||
|
- model='lgbm' : LGBM 단독 (shadow). user_triggered=FALSE 로 항상 적재.
|
||||||
|
|
||||||
|
shadow 행은 retrain_weekly 가 모델별 hit_rate 를 비교해 ensemble_weights 를
|
||||||
|
자동 보정하는 입력이 된다.
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@@ -27,6 +31,9 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
KST = timezone(timedelta(hours=9))
|
KST = timezone(timedelta(hours=9))
|
||||||
|
|
||||||
|
# ±0.3% flat band — features.FLAT_BAND, match_outcomes.FLAT_BAND 와 동일.
|
||||||
|
FLAT_BAND = 0.003
|
||||||
|
|
||||||
|
|
||||||
def _next_trading_target(base_date: date, horizon: int) -> date:
|
def _next_trading_target(base_date: date, horizon: int) -> date:
|
||||||
"""base_date + horizon 거래일 (주말만 스킵, 공휴일은 무시 — 매칭잡이 자연 보정)."""
|
"""base_date + horizon 거래일 (주말만 스킵, 공휴일은 무시 — 매칭잡이 자연 보정)."""
|
||||||
@@ -49,77 +56,194 @@ def _last_trading_date(code: str) -> date | None:
|
|||||||
return row[0] if row and row[0] else None
|
return row[0] if row and row[0] else None
|
||||||
|
|
||||||
|
|
||||||
|
def _direction_label(ret: float) -> str:
|
||||||
|
if ret > FLAT_BAND:
|
||||||
|
return "up"
|
||||||
|
if ret < -FLAT_BAND:
|
||||||
|
return "down"
|
||||||
|
return "flat"
|
||||||
|
|
||||||
|
|
||||||
|
_INSERT_PREDICTION_SQL = text(
|
||||||
|
"""
|
||||||
|
INSERT INTO predictions
|
||||||
|
(code, predicted_at, base_date, target_date, horizon, model,
|
||||||
|
direction, prob_up, prob_flat, prob_down, expected_return,
|
||||||
|
point_forecast, ci_low, ci_high, features_snapshot, user_triggered)
|
||||||
|
VALUES
|
||||||
|
(:code, :predicted_at, :base_date, :target_date, :horizon, :model,
|
||||||
|
:direction, :p_up, :p_fl, :p_dn, :exp_ret,
|
||||||
|
:point, :lo, :hi, CAST(:feats AS JSONB), :ut)
|
||||||
|
ON CONFLICT (code, base_date, target_date, horizon, model)
|
||||||
|
DO UPDATE SET
|
||||||
|
predicted_at = EXCLUDED.predicted_at,
|
||||||
|
direction = EXCLUDED.direction,
|
||||||
|
prob_up = EXCLUDED.prob_up,
|
||||||
|
prob_flat = EXCLUDED.prob_flat,
|
||||||
|
prob_down = EXCLUDED.prob_down,
|
||||||
|
expected_return = EXCLUDED.expected_return,
|
||||||
|
point_forecast = EXCLUDED.point_forecast,
|
||||||
|
ci_low = EXCLUDED.ci_low,
|
||||||
|
ci_high = EXCLUDED.ci_high,
|
||||||
|
features_snapshot = EXCLUDED.features_snapshot,
|
||||||
|
user_triggered = predictions.user_triggered OR EXCLUDED.user_triggered
|
||||||
|
RETURNING id
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _insert_prediction(conn, *, model: str, code: str, predicted_at: datetime,
|
||||||
|
base_date: date, target_date: date, horizon: int,
|
||||||
|
direction: str, p_up: float, p_fl: float, p_dn: float,
|
||||||
|
expected_return: float, point: float | None,
|
||||||
|
lo: float | None, hi: float | None,
|
||||||
|
features_snap: dict, user_triggered: bool) -> int | None:
|
||||||
|
row = conn.execute(
|
||||||
|
_INSERT_PREDICTION_SQL,
|
||||||
|
{
|
||||||
|
"code": code,
|
||||||
|
"predicted_at": predicted_at,
|
||||||
|
"base_date": base_date,
|
||||||
|
"target_date": target_date,
|
||||||
|
"horizon": horizon,
|
||||||
|
"model": model,
|
||||||
|
"direction": direction,
|
||||||
|
"p_up": p_up,
|
||||||
|
"p_fl": p_fl,
|
||||||
|
"p_dn": p_dn,
|
||||||
|
"exp_ret": expected_return,
|
||||||
|
"point": point,
|
||||||
|
"lo": lo,
|
||||||
|
"hi": hi,
|
||||||
|
"feats": json.dumps(features_snap),
|
||||||
|
"ut": user_triggered,
|
||||||
|
},
|
||||||
|
).first()
|
||||||
|
return int(row[0]) if row else None
|
||||||
|
|
||||||
|
|
||||||
def predict_and_store(
|
def predict_and_store(
|
||||||
code: str,
|
code: str,
|
||||||
*,
|
*,
|
||||||
horizons: tuple[int, ...] = (1, 3, 5),
|
horizons: tuple[int, ...] = (1, 3, 5),
|
||||||
user_triggered: bool = True,
|
user_triggered: bool = True,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""앙상블 예측 실행 + predictions 테이블 적재. 결과 JSON-serializable dict 반환."""
|
"""앙상블 예측 실행 + predictions 테이블 적재.
|
||||||
|
|
||||||
|
적재 행:
|
||||||
|
- 'ensemble' (user_triggered 인자 반영)
|
||||||
|
- 'chronos' (shadow, user_triggered=FALSE) — Chronos 가 성공했을 때만
|
||||||
|
- 'lgbm' (shadow, user_triggered=FALSE) — LGBM 이 성공한 horizon 만
|
||||||
|
"""
|
||||||
base_date = _last_trading_date(code)
|
base_date = _last_trading_date(code)
|
||||||
if base_date is None:
|
if base_date is None:
|
||||||
raise RuntimeError(f"no ohlcv_daily for {code}; refresh first")
|
raise RuntimeError(f"no ohlcv_daily for {code}; refresh first")
|
||||||
|
|
||||||
pred: EnsemblePrediction = ensemble_predict(code, horizons=horizons)
|
pred: EnsemblePrediction = ensemble_predict(code, horizons=horizons)
|
||||||
now = datetime.now(KST)
|
now = datetime.now(KST)
|
||||||
|
base_close = pred.base_close
|
||||||
|
|
||||||
eng = get_engine()
|
eng = get_engine()
|
||||||
saved_ids: list[int] = []
|
saved_ids: dict[str, list[int]] = {"ensemble": [], "chronos": [], "lgbm": []}
|
||||||
with eng.begin() as conn:
|
with eng.begin() as conn:
|
||||||
for step in pred.steps:
|
for step in pred.steps:
|
||||||
target_date = _next_trading_target(base_date, step.horizon)
|
target_date = _next_trading_target(base_date, step.horizon)
|
||||||
|
|
||||||
|
# --- ensemble row ---
|
||||||
features_snap = {
|
features_snap = {
|
||||||
"base_close": pred.base_close,
|
"base_close": base_close,
|
||||||
"sources_used": pred.sources_used,
|
"sources_used": pred.sources_used,
|
||||||
"direction": step.direction,
|
"direction": step.direction,
|
||||||
}
|
}
|
||||||
row = conn.execute(
|
pid = _insert_prediction(
|
||||||
text(
|
conn,
|
||||||
"""
|
model="ensemble",
|
||||||
INSERT INTO predictions
|
code=code,
|
||||||
(code, predicted_at, base_date, target_date, horizon, model,
|
predicted_at=now,
|
||||||
direction, prob_up, prob_flat, prob_down, expected_return,
|
base_date=base_date,
|
||||||
point_forecast, ci_low, ci_high, features_snapshot, user_triggered)
|
target_date=target_date,
|
||||||
VALUES
|
horizon=step.horizon,
|
||||||
(:code, :predicted_at, :base_date, :target_date, :horizon, 'ensemble',
|
direction=step.direction,
|
||||||
:direction, :p_up, :p_fl, :p_dn, :exp_ret,
|
p_up=step.prob_up,
|
||||||
:point, :lo, :hi, CAST(:feats AS JSONB), :ut)
|
p_fl=step.prob_flat,
|
||||||
ON CONFLICT (code, base_date, target_date, horizon, model)
|
p_dn=step.prob_down,
|
||||||
DO UPDATE SET
|
expected_return=step.expected_return,
|
||||||
predicted_at = EXCLUDED.predicted_at,
|
point=step.point_close,
|
||||||
direction = EXCLUDED.direction,
|
lo=step.ci_low,
|
||||||
prob_up = EXCLUDED.prob_up,
|
hi=step.ci_high,
|
||||||
prob_flat = EXCLUDED.prob_flat,
|
features_snap=features_snap,
|
||||||
prob_down = EXCLUDED.prob_down,
|
user_triggered=user_triggered,
|
||||||
expected_return = EXCLUDED.expected_return,
|
)
|
||||||
point_forecast = EXCLUDED.point_forecast,
|
if pid is not None:
|
||||||
ci_low = EXCLUDED.ci_low,
|
saved_ids["ensemble"].append(pid)
|
||||||
ci_high = EXCLUDED.ci_high,
|
|
||||||
features_snapshot = EXCLUDED.features_snapshot,
|
# --- chronos shadow row ---
|
||||||
user_triggered = predictions.user_triggered OR EXCLUDED.user_triggered
|
cf = pred.chronos_raw
|
||||||
RETURNING id
|
if cf is not None:
|
||||||
"""
|
c_med = float(cf.median[step.horizon - 1])
|
||||||
),
|
c_q10 = float(cf.q10[step.horizon - 1])
|
||||||
{
|
c_q90 = float(cf.q90[step.horizon - 1])
|
||||||
"code": code,
|
# direction prob: chronos sample 분포
|
||||||
"predicted_at": now,
|
try:
|
||||||
"base_date": base_date,
|
import numpy as np
|
||||||
"target_date": target_date,
|
arr = np.array(cf.samples)[:, step.horizon - 1]
|
||||||
"horizon": step.horizon,
|
ret = arr / base_close - 1.0
|
||||||
"direction": step.direction,
|
cp_up = float((ret > FLAT_BAND).mean())
|
||||||
"p_up": step.prob_up,
|
cp_dn = float((ret < -FLAT_BAND).mean())
|
||||||
"p_fl": step.prob_flat,
|
cp_fl = max(0.0, 1.0 - cp_up - cp_dn)
|
||||||
"p_dn": step.prob_down,
|
except Exception: # noqa: BLE001
|
||||||
"exp_ret": step.expected_return,
|
cp_up = cp_fl = cp_dn = 1.0 / 3.0
|
||||||
"point": step.point_close,
|
exp_ret_c = c_med / base_close - 1.0
|
||||||
"lo": step.ci_low,
|
c_dir = _direction_label(exp_ret_c)
|
||||||
"hi": step.ci_high,
|
pid_c = _insert_prediction(
|
||||||
"feats": json.dumps(features_snap),
|
conn,
|
||||||
"ut": user_triggered,
|
model="chronos",
|
||||||
},
|
code=code,
|
||||||
).first()
|
predicted_at=now,
|
||||||
if row:
|
base_date=base_date,
|
||||||
saved_ids.append(int(row[0]))
|
target_date=target_date,
|
||||||
|
horizon=step.horizon,
|
||||||
|
direction=c_dir,
|
||||||
|
p_up=cp_up,
|
||||||
|
p_fl=cp_fl,
|
||||||
|
p_dn=cp_dn,
|
||||||
|
expected_return=exp_ret_c,
|
||||||
|
point=c_med,
|
||||||
|
lo=c_q10,
|
||||||
|
hi=c_q90,
|
||||||
|
features_snap={"shadow": True, "base_close": base_close},
|
||||||
|
user_triggered=False,
|
||||||
|
)
|
||||||
|
if pid_c is not None:
|
||||||
|
saved_ids["chronos"].append(pid_c)
|
||||||
|
|
||||||
|
# --- lgbm shadow row ---
|
||||||
|
lf = pred.lgbm_raw.get(step.horizon)
|
||||||
|
if lf is not None:
|
||||||
|
l_close = float(lf.predicted_close)
|
||||||
|
exp_ret_l = l_close / base_close - 1.0
|
||||||
|
l_dir = _direction_label(exp_ret_l)
|
||||||
|
pid_l = _insert_prediction(
|
||||||
|
conn,
|
||||||
|
model="lgbm",
|
||||||
|
code=code,
|
||||||
|
predicted_at=now,
|
||||||
|
base_date=base_date,
|
||||||
|
target_date=target_date,
|
||||||
|
horizon=step.horizon,
|
||||||
|
direction=l_dir,
|
||||||
|
p_up=float(lf.prob_up),
|
||||||
|
p_fl=float(lf.prob_flat),
|
||||||
|
p_dn=float(lf.prob_down),
|
||||||
|
expected_return=exp_ret_l,
|
||||||
|
point=l_close,
|
||||||
|
lo=l_close * 0.97,
|
||||||
|
hi=l_close * 1.03,
|
||||||
|
features_snap={"shadow": True, "base_close": base_close},
|
||||||
|
user_triggered=False,
|
||||||
|
)
|
||||||
|
if pid_l is not None:
|
||||||
|
saved_ids["lgbm"].append(pid_l)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"code": code,
|
"code": code,
|
||||||
@@ -133,6 +257,11 @@ def predict_and_store(
|
|||||||
}
|
}
|
||||||
for s in pred.steps
|
for s in pred.steps
|
||||||
],
|
],
|
||||||
"saved_prediction_ids": saved_ids,
|
# UI 는 ensemble id 만 본다. shadow 는 디버깅/검증용으로 별도 키.
|
||||||
|
"saved_prediction_ids": saved_ids["ensemble"],
|
||||||
|
"saved_shadow_ids": {
|
||||||
|
"chronos": saved_ids["chronos"],
|
||||||
|
"lgbm": saved_ids["lgbm"],
|
||||||
|
},
|
||||||
"user_triggered": user_triggered,
|
"user_triggered": user_triggered,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,16 +2,20 @@
|
|||||||
|
|
||||||
일요일 02:00 KST 실행:
|
일요일 02:00 KST 실행:
|
||||||
1. 시드 10종목 × horizon (1,3,5) 별로 LGBM 학습 (train_one).
|
1. 시드 10종목 × horizon (1,3,5) 별로 LGBM 학습 (train_one).
|
||||||
2. 최근 30일 prediction_outcomes 의 model 별 hit_rate 산출, model_performance 적재.
|
2. 최근 30일 prediction_outcomes 의 (code, model, horizon) 별 hit_rate / mae
|
||||||
3. 같은 30일 윈도우에서 chronos vs lgbm hit_rate 로 ensemble_weights 갱신.
|
산출, model_performance 적재.
|
||||||
(방법: w_chronos = clamp(0.1, hr_c / (hr_c + hr_l), 0.9), w_lgbm = 1 - w_chronos.
|
3. shadow 행 (model='chronos' / 'lgbm') 의 hit_rate 를 비교해서
|
||||||
hit_rate 데이터가 부족하면 default 0.6/0.4 유지.)
|
ensemble_weights 자동 보정.
|
||||||
|
|
||||||
지금은 'ensemble' 모델 단일 종류로 predictions 가 쌓이므로, 가중치 보정은
|
가중치 공식:
|
||||||
chronos 단독 시뮬레이션 / lgbm 단독 시뮬레이션 hit_rate 비교가 진정한 방식인데,
|
w_c = clamp(0.1, hr_c / (hr_c + hr_l), 0.9)
|
||||||
Phase 4 단순화: 'ensemble' 의 종합 hit_rate 만 model_performance 에 기록하고
|
w_l = 1 - w_c
|
||||||
가중치는 default 유지. 진짜 비교는 Phase 7 (chronos 단독 + lgbm 단독 예측을
|
단 sample_count_c < MIN_SAMPLE 또는 sample_count_l < MIN_SAMPLE 이면
|
||||||
shadow 로 같이 적재하는 구조) 로 확장.
|
기본값 유지 (DB row 미생성). hr_c + hr_l == 0 (둘 다 0%) 이면 50:50.
|
||||||
|
|
||||||
|
predict_one 이 매 호출마다 chronos/lgbm shadow 행을 함께 적재하고
|
||||||
|
match_outcomes 가 user_triggered 무관하게 매칭하므로, hit_rate 데이터는
|
||||||
|
사용자가 예측을 한 번이라도 본 종목에 대해 자연스럽게 쌓인다.
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@@ -24,12 +28,16 @@ from sqlalchemy import text
|
|||||||
|
|
||||||
from app.db.connection import get_engine
|
from app.db.connection import get_engine
|
||||||
from app.models.lgbm import train_one
|
from app.models.lgbm import train_one
|
||||||
|
from app.models.weights import upsert_weights
|
||||||
from app.seed.seed_tickers import SEED_TICKERS
|
from app.seed.seed_tickers import SEED_TICKERS
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
HORIZONS = (1, 3, 5)
|
HORIZONS = (1, 3, 5)
|
||||||
WINDOW_DAYS = 30
|
WINDOW_DAYS = 30
|
||||||
|
MIN_SAMPLE = 10 # 모델당 최소 매칭 표본
|
||||||
|
W_CHRONOS_MIN = 0.1
|
||||||
|
W_CHRONOS_MAX = 0.9
|
||||||
|
|
||||||
|
|
||||||
def retrain_all() -> list[dict[str, Any]]:
|
def retrain_all() -> list[dict[str, Any]]:
|
||||||
@@ -103,6 +111,82 @@ def record_performance(as_of: date) -> list[dict[str, Any]]:
|
|||||||
return summary
|
return summary
|
||||||
|
|
||||||
|
|
||||||
|
def adjust_weights(as_of: date) -> list[dict[str, Any]]:
|
||||||
|
"""shadow chronos/lgbm hit_rate 로 ensemble_weights 자동 보정.
|
||||||
|
|
||||||
|
반환: (code, horizon, w_chronos, w_lgbm, hr_c, hr_l, n_c, n_l, action) 의
|
||||||
|
리스트. action ∈ {'updated', 'skipped_insufficient', 'skipped_zero'}.
|
||||||
|
"""
|
||||||
|
eng = get_engine()
|
||||||
|
start = as_of - timedelta(days=WINDOW_DAYS)
|
||||||
|
out: list[dict[str, Any]] = []
|
||||||
|
with eng.begin() as conn:
|
||||||
|
rows = conn.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
SELECT code, horizon, model,
|
||||||
|
AVG(CASE WHEN direction_hit THEN 1.0 ELSE 0.0 END) AS hit_rate,
|
||||||
|
COUNT(*) AS n
|
||||||
|
FROM prediction_outcomes
|
||||||
|
WHERE resolved_at >= :start
|
||||||
|
AND model IN ('chronos', 'lgbm')
|
||||||
|
GROUP BY code, horizon, model
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{"start": start},
|
||||||
|
).all()
|
||||||
|
|
||||||
|
# (code, horizon) -> {'chronos': (hr, n), 'lgbm': (hr, n)}
|
||||||
|
agg: dict[tuple[str, int], dict[str, tuple[float, int]]] = {}
|
||||||
|
for code, horizon, model, hr, n in rows:
|
||||||
|
key = (code, int(horizon))
|
||||||
|
agg.setdefault(key, {})[str(model)] = (
|
||||||
|
float(hr) if hr is not None else 0.0,
|
||||||
|
int(n),
|
||||||
|
)
|
||||||
|
|
||||||
|
for (code, horizon), m in agg.items():
|
||||||
|
c = m.get("chronos")
|
||||||
|
l = m.get("lgbm")
|
||||||
|
if c is None or l is None:
|
||||||
|
out.append({
|
||||||
|
"code": code, "horizon": horizon,
|
||||||
|
"action": "skipped_missing_model",
|
||||||
|
"have": list(m.keys()),
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
hr_c, n_c = c
|
||||||
|
hr_l, n_l = l
|
||||||
|
if n_c < MIN_SAMPLE or n_l < MIN_SAMPLE:
|
||||||
|
out.append({
|
||||||
|
"code": code, "horizon": horizon,
|
||||||
|
"hr_chronos": hr_c, "hr_lgbm": hr_l,
|
||||||
|
"n_chronos": n_c, "n_lgbm": n_l,
|
||||||
|
"action": "skipped_insufficient",
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
total = hr_c + hr_l
|
||||||
|
if total <= 0:
|
||||||
|
w_c = 0.5
|
||||||
|
else:
|
||||||
|
w_c = hr_c / total
|
||||||
|
w_c = max(W_CHRONOS_MIN, min(W_CHRONOS_MAX, w_c))
|
||||||
|
w_l = 1.0 - w_c
|
||||||
|
upsert_weights(
|
||||||
|
code, horizon, w_c, w_l,
|
||||||
|
hit_rate_chronos=hr_c, hit_rate_lgbm=hr_l,
|
||||||
|
sample_count=min(n_c, n_l),
|
||||||
|
)
|
||||||
|
out.append({
|
||||||
|
"code": code, "horizon": horizon,
|
||||||
|
"hr_chronos": hr_c, "hr_lgbm": hr_l,
|
||||||
|
"n_chronos": n_c, "n_lgbm": n_l,
|
||||||
|
"w_chronos": w_c, "w_lgbm": w_l,
|
||||||
|
"action": "updated",
|
||||||
|
})
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
def run_weekly() -> dict[str, Any]:
|
def run_weekly() -> dict[str, Any]:
|
||||||
"""일요일 02:00 KST 호출 entry-point."""
|
"""일요일 02:00 KST 호출 entry-point."""
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
@@ -113,6 +197,7 @@ def run_weekly() -> dict[str, Any]:
|
|||||||
"as_of": str(as_of),
|
"as_of": str(as_of),
|
||||||
"trained": retrain_all(),
|
"trained": retrain_all(),
|
||||||
"performance": record_performance(as_of),
|
"performance": record_performance(as_of),
|
||||||
|
"weights": adjust_weights(as_of),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -72,6 +72,7 @@ export type PredictResponse = {
|
|||||||
sources_used: string[];
|
sources_used: string[];
|
||||||
steps: PredictionStep[];
|
steps: PredictionStep[];
|
||||||
saved_prediction_ids: number[];
|
saved_prediction_ids: number[];
|
||||||
|
saved_shadow_ids?: { chronos: number[]; lgbm: number[] };
|
||||||
user_triggered: boolean;
|
user_triggered: boolean;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user