Compare commits

..

3 Commits

Author SHA1 Message Date
tkrmagid
5e6ce11491 feat(weights): shadow chronos/lgbm 예측 + ensemble_weights 자동 보정
- ensemble.predict() 가 chronos_raw / lgbm_raw 를 함께 반환
- predict_and_store() 가 매 호출마다 3종 행 적재:
    model='ensemble' (user_triggered=인자)
    model='chronos'  (user_triggered=FALSE, shadow)
    model='lgbm'     (user_triggered=FALSE, shadow)
- retrain_weekly.adjust_weights(): 최근 30일 prediction_outcomes 의
  chronos vs lgbm hit_rate 로 ensemble_weights upsert
    w_chronos = clamp(0.1, hr_c/(hr_c+hr_l), 0.9), w_lgbm = 1 - w_chronos
  모델별 표본 < 10 이면 기본값(0.6/0.4) 유지
- API 응답에 saved_shadow_ids 추가 (TS 타입도 동기화)
- README: 동작 모델 메모 섹션을 실제 구현과 일치하도록 갱신

리뷰어 지적 3번 (ensemble_weights 가 영원히 갱신 안됨, upsert_weights 미호출) 해결.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-20 16:27:21 +09:00
tkrmagid
0af556396e fix(match): 주말/공휴일 이월 매칭 (target_date <= today + 최초 거래일 종가)
- match_for_date(d) → match_up_to(today) 로 시맨틱 변경: target_date == d
  대신 target_date <= today AND outcomes 미존재 전부 후보로
- 각 후보마다 ohlcv_daily 에서 target_date 이상 today 이하 범위의 최초
  거래일 행을 actual_close 로 매칭 → 주말/공휴일 자동 이월
- user_triggered 필터 제거: chronos/lgbm shadow 행도 함께 매칭됨
- prediction_outcomes.target_date 에는 실제 매칭된 거래일을 기록
- 하위 호환: match_for_date(d) 는 match_up_to(d) 별칭으로 유지

리뷰어 지적 2번 (공휴일/주말이면 target_date 일치 행이 영원히 미매칭) 해결.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-20 16:27:09 +09:00
tkrmagid
f84b460e54 fix(bootstrap): backend lifespan 에서 DB migrate + symbols 시드 자동화
- main.py 의 lifespan 시작 시 idempotent migration 적용 + symbols 비어있으면 pykrx 로 전 종목 시드
- BOOTSTRAP_DISABLED=1 / SCHEDULER_DISABLED=1 env 로 비활성 가능 (테스트 용)
- 실패해도 서버는 뜨고 /health/db 가 진단 제공

리뷰어 지적 1번 (cold-start 시 /api/refresh 404) 해결.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-20 16:26:59 +09:00
7 changed files with 403 additions and 124 deletions

View File

@@ -155,9 +155,13 @@ stock_chart_site/
## 동작 모델 메모 ## 동작 모델 메모
- 예측 트리거: 사용자가 "예상차트 보기" 누른 종목에 대해 즉시 inference. 결과는 `predictions(user_triggered=TRUE)` 로 저장. - 예측 트리거: 사용자가 "예상차트 보기" 누른 종목에 대해 즉시 inference. 결과는 세 종류 행으로 적재:
- 매칭 배치: 다음 거래일 장 종료 후 (KRX 정규장 마감 15:30, 종가 확정 후 16:00 ~ 16:30 KST 사이) `user_triggered=TRUE` 인 예측 중 `target_date == 오늘 거래일`인 행들에 대해 실제 종가/방향과 매칭 → `prediction_outcomes` 적재. 주말/공휴일이면 다음 거래일로 이월. - `model='ensemble'` (user_triggered=TRUE) — UI 가 표시하는 최종 예측
- 주간 02:00 (일요일): 종목/모델별 최근 30일 hit rate 기반으로 앙상블 가중치를 자동 보정. hit rate가 임계 미만이면 LGBM 재학습. - `model='chronos'` (user_triggered=FALSE, shadow) — Chronos 단독 성능 추적용
- `model='lgbm'` (user_triggered=FALSE, shadow) — LGBM 단독 성능 추적용
- 매칭 배치: 평일 16:30 KST. `target_date <= today AND outcomes 미존재` 인 모든 행에 대해 `target_date` 이상 `today` 이하 범위의 **최초 거래일 종가**를 actual_close 로 사용 → 주말/공휴일 자동 이월. shadow 행도 함께 매칭됨.
- 주간 02:00 (일요일): 시드 10종목 × horizons LGBM 재학습. 최근 30일 prediction_outcomes 의 chronos vs lgbm hit_rate 비교 → `w_chronos = clamp(0.1, hr_c/(hr_c+hr_l), 0.9)` 공식으로 `ensemble_weights` upsert. 모델별 표본이 10 미만이면 기본값(0.6/0.4) 유지.
- DB bootstrap: 백엔드 첫 부팅 시 lifespan 에서 idempotent migration + symbols 시드(비어있을 때만 pykrx 전 종목 적재) 자동 수행. `BOOTSTRAP_DISABLED=1` 로 비활성화 가능.
## 안전/한계 ## 안전/한계

View File

@@ -1,10 +1,12 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import text
from app.api.chart import router as chart_router from app.api.chart import router as chart_router
from app.api.metrics import router as metrics_router from app.api.metrics import router as metrics_router
@@ -13,7 +15,7 @@ from app.api.predict import router as predict_router
from app.api.refresh import router as refresh_router from app.api.refresh import router as refresh_router
from app.api.symbols import router as symbols_router from app.api.symbols import router as symbols_router
from app.config import settings from app.config import settings
from app.db.connection import ping as db_ping from app.db.connection import get_engine, ping as db_ping
from app.fetch import dart as dart_mod from app.fetch import dart as dart_mod
from app.fetch import kis as kis_mod from app.fetch import kis as kis_mod
from app.pipelines.scheduler import shutdown_scheduler, start_scheduler from app.pipelines.scheduler import shutdown_scheduler, start_scheduler
@@ -25,13 +27,55 @@ logging.basicConfig(
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _bootstrap_db() -> None:
"""첫 부팅 자동화:
1) migrations/*.sql idempotent 적용 (timescale/pgvector 확장 + 스키마)
2) symbols 테이블 비어있으면 pykrx 로 전 종목 시드 (SEED 10 마크 포함)
BOOTSTRAP_DISABLED=1 이면 스킵 (테스트/CI 용). 어떤 단계든 실패해도 서버는
뜬다 — /health/db 가 진단을 알려준다.
"""
if os.environ.get("BOOTSTRAP_DISABLED") == "1":
logger.info("bootstrap skipped (BOOTSTRAP_DISABLED=1)")
return
# 1) migrations
try:
from app.db.migrate import apply_all
res = apply_all()
logger.info("bootstrap migrate: %s", res)
except Exception: # noqa: BLE001
logger.exception("bootstrap migrate failed")
return # 스키마 없으면 시드 불가
# 2) symbols 시드 (비어있을 때만 — pykrx 호출이 비싸므로 항상 돌리지 않음)
try:
eng = get_engine()
with eng.connect() as conn:
row = conn.execute(text("SELECT COUNT(*) FROM symbols")).first()
count = int(row[0]) if row else 0
if count == 0:
logger.info("symbols empty — running initial seed")
from app.fetch.symbols_seed import seed_symbols
report = seed_symbols()
logger.info("bootstrap seed_symbols: %s", report)
else:
logger.info("symbols already populated (count=%d) — skip seed", count)
except Exception: # noqa: BLE001
logger.exception("bootstrap seed_symbols failed")
@asynccontextmanager @asynccontextmanager
async def lifespan(_: FastAPI): async def lifespan(_: FastAPI):
_bootstrap_db()
# 스케줄러는 옵션. CI/테스트에서 disable 하고 싶으면 SCHEDULER_DISABLED 같은 env 추가 가능. # 스케줄러는 옵션. CI/테스트에서 disable 하고 싶으면 SCHEDULER_DISABLED 같은 env 추가 가능.
try: if os.environ.get("SCHEDULER_DISABLED") == "1":
start_scheduler() logger.info("scheduler skipped (SCHEDULER_DISABLED=1)")
except Exception: # noqa: BLE001 else:
logger.exception("scheduler start failed") try:
start_scheduler()
except Exception: # noqa: BLE001
logger.exception("scheduler start failed")
yield yield
shutdown_scheduler() shutdown_scheduler()

View File

@@ -19,7 +19,7 @@ Chronos 도 실패하면 LGBM 단독으로 진행.
from __future__ import annotations from __future__ import annotations
import logging import logging
from dataclasses import dataclass from dataclasses import dataclass, field
import numpy as np import numpy as np
@@ -53,6 +53,10 @@ class EnsemblePrediction:
horizons: list[int] horizons: list[int]
steps: list[EnsembleStep] steps: list[EnsembleStep]
sources_used: list[str] sources_used: list[str]
# shadow 저장용 원본 출력 (predict_one.py 가 ensemble + chronos 단독 + lgbm 단독
# 3 종을 predictions 에 적재해서 retrain_weekly 가 모델별 hit_rate 비교 가능하게 함).
chronos_raw: ChronosForecast | None = None
lgbm_raw: dict[int, LgbmForecast] = field(default_factory=dict)
def _chronos_direction(samples: list[list[float]], base_close: float, horizon: int) -> tuple[float, float, float]: def _chronos_direction(samples: list[list[float]], base_close: float, horizon: int) -> tuple[float, float, float]:
@@ -90,12 +94,14 @@ def predict(code: str, *, horizons: tuple[int, ...] = (1, 3, 5)) -> EnsemblePred
logger.warning("chronos forecast failed for %s: %s", code, exc) logger.warning("chronos forecast failed for %s: %s", code, exc)
steps: list[EnsembleStep] = [] steps: list[EnsembleStep] = []
lgbm_raw: dict[int, LgbmForecast] = {}
for h in horizons: for h in horizons:
lf: LgbmForecast | None = None lf: LgbmForecast | None = None
try: try:
lf = lgbm_predict(code, h) lf = lgbm_predict(code, h)
if lf is not None: if lf is not None:
sources_used.append(f"lgbm_h{h}") sources_used.append(f"lgbm_h{h}")
lgbm_raw[h] = lf
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
logger.warning("lgbm predict failed for %s h=%d: %s", code, h, exc) logger.warning("lgbm predict failed for %s h=%d: %s", code, h, exc)
@@ -171,4 +177,6 @@ def predict(code: str, *, horizons: tuple[int, ...] = (1, 3, 5)) -> EnsemblePred
horizons=list(horizons), horizons=list(horizons),
steps=steps, steps=steps,
sources_used=sources_used, sources_used=sources_used,
chronos_raw=cf,
lgbm_raw=lgbm_raw,
) )

View File

@@ -1,13 +1,16 @@
"""prediction_outcomes 매칭 배치. """prediction_outcomes 매칭 배치.
평일 16:30 KST 에 실행. 다음 거래일 장 종료 후 (KRX 정규장 마감 15:30) 의 평일 16:30 KST 에 실행. 다음 거래일 장 종료 후 (KRX 정규장 마감 15:30) 의
확정 종가가 16:00~16:30 사이 pykrx 로 들어온 뒤, target_date == 오늘인 확정 종가가 16:00~16:30 사이 pykrx 로 들어온 뒤, 매칭 미해결 예측을 실제
user_triggered=TRUE 예측을 그 종가와 매칭. 종가와 매칭한다.
cold-start / 휴장일 대비: 인자로 받은 target_date 의 ohlcv_daily 에 종가가 이월/공휴일 정책:
없으면 자연스럽게 skip. 다음 거래일 매칭 잡이 다시 시도하면 그 날짜는 target_date 가 calendar date 라서 비거래일이면 ohlcv_daily 에 행이 없다.
여전히 매칭되지 않으므로 (매칭 sql 이 target_date 기준), 영원히 매칭 안되는 그래서 `target_date <= today` 인 미해결 행을 전부 후보로 잡고, 각 행마다
잘못된 calendar date 예측은 cleanup CLI 로 별도 정리 가능 (Phase 7). `target_date <= ohlcv_daily.date <= today` 범위의 최초 거래일 종가로
매칭한다 (=다음 거래일로 자동 이월).
shadow prediction 도 같은 방식으로 매칭한다 (user_triggered 필터 없음).
""" """
from __future__ import annotations from __future__ import annotations
@@ -29,7 +32,7 @@ FLAT_BAND = 0.003
@dataclass @dataclass
class MatchSummary: class MatchSummary:
target_date: str today: str
candidates: int candidates: int
matched: int matched: int
skipped_no_actual: int skipped_no_actual: int
@@ -44,64 +47,62 @@ def _direction_label(ret: float) -> str:
return "flat" return "flat"
def match_for_date(d: date) -> MatchSummary: def match_up_to(today: date) -> MatchSummary:
"""target_date == d 인 user_triggered=TRUE 예측을 매칭.""" """target_date <= today 인 모든 미해결 예측을 매칭.
각 행마다 ohlcv_daily 에서 target_date 이상, today 이하 범위의 최초
거래일 종가를 actual_close 로 사용 — 공휴일/주말 이월 자연 처리.
"""
eng = get_engine() eng = get_engine()
with eng.begin() as conn: with eng.begin() as conn:
# 매칭 대상 예측 + 매칭 안 됐는지 확인.
candidate_rows = conn.execute( candidate_rows = conn.execute(
text( text(
""" """
SELECT p.id, p.code, p.base_date, p.horizon, p.point_forecast, SELECT p.id, p.code, p.base_date, p.target_date, p.horizon,
p.direction, p.model p.point_forecast, p.direction, p.model
FROM predictions p FROM predictions p
LEFT JOIN prediction_outcomes po ON po.prediction_id = p.id LEFT JOIN prediction_outcomes po ON po.prediction_id = p.id
WHERE p.target_date = :d WHERE p.target_date <= :today
AND p.user_triggered = TRUE
AND po.prediction_id IS NULL AND po.prediction_id IS NULL
""" """
), ),
{"d": d}, {"today": today},
).all() ).all()
candidates = len(candidate_rows) candidates = len(candidate_rows)
if not candidates: if not candidates:
return MatchSummary(str(d), 0, 0, 0, 0) return MatchSummary(str(today), 0, 0, 0, 0)
# 종목별로 actual close 조회 (한번에 batch).
codes = list({r[1] for r in candidate_rows})
actual_map: dict[tuple[str, date], float] = {}
for code in codes:
row = conn.execute(
text(
"SELECT close FROM ohlcv_daily WHERE code = :c AND date = :d"
),
{"c": code, "d": d},
).first()
if row and row[0] is not None:
actual_map[(code, d)] = float(row[0])
# base_close (각 예측의 base_date 종가) 도 필요 — direction 판정용.
base_close_map: dict[tuple[str, date], float] = {}
for pid, code, base_date, *_ in candidate_rows:
key = (code, base_date)
if key in base_close_map:
continue
row = conn.execute(
text("SELECT close FROM ohlcv_daily WHERE code = :c AND date = :d"),
{"c": code, "d": base_date},
).first()
if row and row[0] is not None:
base_close_map[key] = float(row[0])
matched = 0 matched = 0
skipped = 0 skipped = 0
already = 0 already = 0
for pid, code, base_date, horizon, point_forecast, pred_dir, model in candidate_rows: for pid, code, base_date, target_date, horizon, point_forecast, pred_dir, model in candidate_rows:
actual = actual_map.get((code, d)) # 첫 거래일 종가 (target_date <= date <= today)
base_close = base_close_map.get((code, base_date)) actual_row = conn.execute(
if actual is None or base_close is None: text(
"""
SELECT date, close FROM ohlcv_daily
WHERE code = :c AND date >= :td AND date <= :today
ORDER BY date ASC
LIMIT 1
"""
),
{"c": code, "td": target_date, "today": today},
).first()
if not actual_row or actual_row[1] is None:
skipped += 1 skipped += 1
continue continue
actual_date = actual_row[0]
actual = float(actual_row[1])
base_close_row = conn.execute(
text("SELECT close FROM ohlcv_daily WHERE code = :c AND date = :d"),
{"c": code, "d": base_date},
).first()
if not base_close_row or base_close_row[0] is None:
skipped += 1
continue
base_close = float(base_close_row[0])
actual_ret = actual / base_close - 1.0 actual_ret = actual / base_close - 1.0
actual_dir = _direction_label(actual_ret) actual_dir = _direction_label(actual_ret)
dir_hit = (pred_dir == actual_dir) dir_hit = (pred_dir == actual_dir)
@@ -122,7 +123,8 @@ def match_for_date(d: date) -> MatchSummary:
{ {
"pid": pid, "pid": pid,
"code": code, "code": code,
"d": d, # 실제 매칭된 거래일 (이월된 경우 target_date 와 다를 수 있음)
"d": actual_date,
"h": horizon, "h": horizon,
"m": model, "m": model,
"pc": float(point_forecast) if point_forecast is not None else None, "pc": float(point_forecast) if point_forecast is not None else None,
@@ -138,7 +140,7 @@ def match_for_date(d: date) -> MatchSummary:
already += 1 already += 1
return MatchSummary( return MatchSummary(
target_date=str(d), today=str(today),
candidates=candidates, candidates=candidates,
matched=matched, matched=matched,
skipped_no_actual=skipped, skipped_no_actual=skipped,
@@ -146,13 +148,19 @@ def match_for_date(d: date) -> MatchSummary:
) )
# 하위 호환 alias — 이전 시그니처를 쓰던 호출자 (예: 단일 날짜 매칭 테스트)
def match_for_date(d: date) -> MatchSummary:
"""legacy: target_date == d 만 매칭하던 동작 → 이제 target_date <= d 전체 처리."""
return match_up_to(d)
def match_today() -> dict[str, Any]: def match_today() -> dict[str, Any]:
"""평일 16:30 KST 호출용. target_date == today (KST) 인 매칭.""" """평일 16:30 KST 호출용. target_date <= today (KST) 인 미해결 행 일괄 매칭."""
from datetime import datetime, timezone, timedelta as td from datetime import datetime, timezone, timedelta as td
kst = timezone(td(hours=9)) kst = timezone(td(hours=9))
today = datetime.now(kst).date() today = datetime.now(kst).date()
summary = match_for_date(today) summary = match_up_to(today)
return { return {
"today": str(today), "today": str(today),
"summary": summary.__dict__, "summary": summary.__dict__,

View File

@@ -3,12 +3,16 @@
POST /api/predict/{code} 에서 호출. 사용자가 "예상차트 보기" 누른 시점. POST /api/predict/{code} 에서 호출. 사용자가 "예상차트 보기" 누른 시점.
- ensemble.predict() 로 horizons (1,3,5) 결과 계산 - ensemble.predict() 로 horizons (1,3,5) 결과 계산
- base_date = 마지막 ohlcv_daily.date, target_date = base_date + horizon 영업일 - base_date = 마지막 ohlcv_daily.date, target_date = base_date + horizon 영업일
(대충 calendar 일로 +h * 1.4 — KRX 영업일 추정. Phase 4 단순화: base_date + h 영업일은 (주말만 스킵하는 단순 카운트. 공휴일은 match_outcomes 가 "target_date 이후
ohlcv 상의 다음 h 거래일이 아닌, "거래일 카운트" 대신 단순 calendar+h 로 저장하고 최초 거래일 종가"로 자동 이월하여 보정.)
매칭 잡에서 ohlcv_daily 에 그 날짜 행이 있는지로 자연 보정.)
대안 정확도 위해: 매칭 잡은 "예측의 target_date 이 오늘"인 행을 그날 종가와 비교. 세 종류의 행을 함께 저장한다:
calendar date 가 비거래일이면 매칭이 안 되니, 매칭 잡은 매일 실행되어 모일 때 처리. - model='ensemble' : 사용자에게 보여주는 최종 예측. user_triggered 플래그 따라감.
- model='chronos' : Chronos 단독 (shadow). user_triggered=FALSE 로 항상 적재.
- model='lgbm' : LGBM 단독 (shadow). user_triggered=FALSE 로 항상 적재.
shadow 행은 retrain_weekly 가 모델별 hit_rate 를 비교해 ensemble_weights 를
자동 보정하는 입력이 된다.
""" """
from __future__ import annotations from __future__ import annotations
@@ -27,6 +31,9 @@ logger = logging.getLogger(__name__)
KST = timezone(timedelta(hours=9)) KST = timezone(timedelta(hours=9))
# ±0.3% flat band — features.FLAT_BAND, match_outcomes.FLAT_BAND 와 동일.
FLAT_BAND = 0.003
def _next_trading_target(base_date: date, horizon: int) -> date: def _next_trading_target(base_date: date, horizon: int) -> date:
"""base_date + horizon 거래일 (주말만 스킵, 공휴일은 무시 — 매칭잡이 자연 보정).""" """base_date + horizon 거래일 (주말만 스킵, 공휴일은 무시 — 매칭잡이 자연 보정)."""
@@ -49,77 +56,194 @@ def _last_trading_date(code: str) -> date | None:
return row[0] if row and row[0] else None return row[0] if row and row[0] else None
def _direction_label(ret: float) -> str:
if ret > FLAT_BAND:
return "up"
if ret < -FLAT_BAND:
return "down"
return "flat"
_INSERT_PREDICTION_SQL = text(
"""
INSERT INTO predictions
(code, predicted_at, base_date, target_date, horizon, model,
direction, prob_up, prob_flat, prob_down, expected_return,
point_forecast, ci_low, ci_high, features_snapshot, user_triggered)
VALUES
(:code, :predicted_at, :base_date, :target_date, :horizon, :model,
:direction, :p_up, :p_fl, :p_dn, :exp_ret,
:point, :lo, :hi, CAST(:feats AS JSONB), :ut)
ON CONFLICT (code, base_date, target_date, horizon, model)
DO UPDATE SET
predicted_at = EXCLUDED.predicted_at,
direction = EXCLUDED.direction,
prob_up = EXCLUDED.prob_up,
prob_flat = EXCLUDED.prob_flat,
prob_down = EXCLUDED.prob_down,
expected_return = EXCLUDED.expected_return,
point_forecast = EXCLUDED.point_forecast,
ci_low = EXCLUDED.ci_low,
ci_high = EXCLUDED.ci_high,
features_snapshot = EXCLUDED.features_snapshot,
user_triggered = predictions.user_triggered OR EXCLUDED.user_triggered
RETURNING id
"""
)
def _insert_prediction(conn, *, model: str, code: str, predicted_at: datetime,
base_date: date, target_date: date, horizon: int,
direction: str, p_up: float, p_fl: float, p_dn: float,
expected_return: float, point: float | None,
lo: float | None, hi: float | None,
features_snap: dict, user_triggered: bool) -> int | None:
row = conn.execute(
_INSERT_PREDICTION_SQL,
{
"code": code,
"predicted_at": predicted_at,
"base_date": base_date,
"target_date": target_date,
"horizon": horizon,
"model": model,
"direction": direction,
"p_up": p_up,
"p_fl": p_fl,
"p_dn": p_dn,
"exp_ret": expected_return,
"point": point,
"lo": lo,
"hi": hi,
"feats": json.dumps(features_snap),
"ut": user_triggered,
},
).first()
return int(row[0]) if row else None
def predict_and_store( def predict_and_store(
code: str, code: str,
*, *,
horizons: tuple[int, ...] = (1, 3, 5), horizons: tuple[int, ...] = (1, 3, 5),
user_triggered: bool = True, user_triggered: bool = True,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""앙상블 예측 실행 + predictions 테이블 적재. 결과 JSON-serializable dict 반환.""" """앙상블 예측 실행 + predictions 테이블 적재.
적재 행:
- 'ensemble' (user_triggered 인자 반영)
- 'chronos' (shadow, user_triggered=FALSE) — Chronos 가 성공했을 때만
- 'lgbm' (shadow, user_triggered=FALSE) — LGBM 이 성공한 horizon 만
"""
base_date = _last_trading_date(code) base_date = _last_trading_date(code)
if base_date is None: if base_date is None:
raise RuntimeError(f"no ohlcv_daily for {code}; refresh first") raise RuntimeError(f"no ohlcv_daily for {code}; refresh first")
pred: EnsemblePrediction = ensemble_predict(code, horizons=horizons) pred: EnsemblePrediction = ensemble_predict(code, horizons=horizons)
now = datetime.now(KST) now = datetime.now(KST)
base_close = pred.base_close
eng = get_engine() eng = get_engine()
saved_ids: list[int] = [] saved_ids: dict[str, list[int]] = {"ensemble": [], "chronos": [], "lgbm": []}
with eng.begin() as conn: with eng.begin() as conn:
for step in pred.steps: for step in pred.steps:
target_date = _next_trading_target(base_date, step.horizon) target_date = _next_trading_target(base_date, step.horizon)
# --- ensemble row ---
features_snap = { features_snap = {
"base_close": pred.base_close, "base_close": base_close,
"sources_used": pred.sources_used, "sources_used": pred.sources_used,
"direction": step.direction, "direction": step.direction,
} }
row = conn.execute( pid = _insert_prediction(
text( conn,
""" model="ensemble",
INSERT INTO predictions code=code,
(code, predicted_at, base_date, target_date, horizon, model, predicted_at=now,
direction, prob_up, prob_flat, prob_down, expected_return, base_date=base_date,
point_forecast, ci_low, ci_high, features_snapshot, user_triggered) target_date=target_date,
VALUES horizon=step.horizon,
(:code, :predicted_at, :base_date, :target_date, :horizon, 'ensemble', direction=step.direction,
:direction, :p_up, :p_fl, :p_dn, :exp_ret, p_up=step.prob_up,
:point, :lo, :hi, CAST(:feats AS JSONB), :ut) p_fl=step.prob_flat,
ON CONFLICT (code, base_date, target_date, horizon, model) p_dn=step.prob_down,
DO UPDATE SET expected_return=step.expected_return,
predicted_at = EXCLUDED.predicted_at, point=step.point_close,
direction = EXCLUDED.direction, lo=step.ci_low,
prob_up = EXCLUDED.prob_up, hi=step.ci_high,
prob_flat = EXCLUDED.prob_flat, features_snap=features_snap,
prob_down = EXCLUDED.prob_down, user_triggered=user_triggered,
expected_return = EXCLUDED.expected_return, )
point_forecast = EXCLUDED.point_forecast, if pid is not None:
ci_low = EXCLUDED.ci_low, saved_ids["ensemble"].append(pid)
ci_high = EXCLUDED.ci_high,
features_snapshot = EXCLUDED.features_snapshot, # --- chronos shadow row ---
user_triggered = predictions.user_triggered OR EXCLUDED.user_triggered cf = pred.chronos_raw
RETURNING id if cf is not None:
""" c_med = float(cf.median[step.horizon - 1])
), c_q10 = float(cf.q10[step.horizon - 1])
{ c_q90 = float(cf.q90[step.horizon - 1])
"code": code, # direction prob: chronos sample 분포
"predicted_at": now, try:
"base_date": base_date, import numpy as np
"target_date": target_date, arr = np.array(cf.samples)[:, step.horizon - 1]
"horizon": step.horizon, ret = arr / base_close - 1.0
"direction": step.direction, cp_up = float((ret > FLAT_BAND).mean())
"p_up": step.prob_up, cp_dn = float((ret < -FLAT_BAND).mean())
"p_fl": step.prob_flat, cp_fl = max(0.0, 1.0 - cp_up - cp_dn)
"p_dn": step.prob_down, except Exception: # noqa: BLE001
"exp_ret": step.expected_return, cp_up = cp_fl = cp_dn = 1.0 / 3.0
"point": step.point_close, exp_ret_c = c_med / base_close - 1.0
"lo": step.ci_low, c_dir = _direction_label(exp_ret_c)
"hi": step.ci_high, pid_c = _insert_prediction(
"feats": json.dumps(features_snap), conn,
"ut": user_triggered, model="chronos",
}, code=code,
).first() predicted_at=now,
if row: base_date=base_date,
saved_ids.append(int(row[0])) target_date=target_date,
horizon=step.horizon,
direction=c_dir,
p_up=cp_up,
p_fl=cp_fl,
p_dn=cp_dn,
expected_return=exp_ret_c,
point=c_med,
lo=c_q10,
hi=c_q90,
features_snap={"shadow": True, "base_close": base_close},
user_triggered=False,
)
if pid_c is not None:
saved_ids["chronos"].append(pid_c)
# --- lgbm shadow row ---
lf = pred.lgbm_raw.get(step.horizon)
if lf is not None:
l_close = float(lf.predicted_close)
exp_ret_l = l_close / base_close - 1.0
l_dir = _direction_label(exp_ret_l)
pid_l = _insert_prediction(
conn,
model="lgbm",
code=code,
predicted_at=now,
base_date=base_date,
target_date=target_date,
horizon=step.horizon,
direction=l_dir,
p_up=float(lf.prob_up),
p_fl=float(lf.prob_flat),
p_dn=float(lf.prob_down),
expected_return=exp_ret_l,
point=l_close,
lo=l_close * 0.97,
hi=l_close * 1.03,
features_snap={"shadow": True, "base_close": base_close},
user_triggered=False,
)
if pid_l is not None:
saved_ids["lgbm"].append(pid_l)
return { return {
"code": code, "code": code,
@@ -133,6 +257,11 @@ def predict_and_store(
} }
for s in pred.steps for s in pred.steps
], ],
"saved_prediction_ids": saved_ids, # UI 는 ensemble id 만 본다. shadow 는 디버깅/검증용으로 별도 키.
"saved_prediction_ids": saved_ids["ensemble"],
"saved_shadow_ids": {
"chronos": saved_ids["chronos"],
"lgbm": saved_ids["lgbm"],
},
"user_triggered": user_triggered, "user_triggered": user_triggered,
} }

View File

@@ -2,16 +2,20 @@
일요일 02:00 KST 실행: 일요일 02:00 KST 실행:
1. 시드 10종목 × horizon (1,3,5) 별로 LGBM 학습 (train_one). 1. 시드 10종목 × horizon (1,3,5) 별로 LGBM 학습 (train_one).
2. 최근 30일 prediction_outcomes 의 model 별 hit_rate 산출, model_performance 적재. 2. 최근 30일 prediction_outcomes 의 (code, model, horizon) 별 hit_rate / mae
3. 같은 30일 윈도우에서 chronos vs lgbm hit_rate 로 ensemble_weights 갱신. 산출, model_performance 적재.
(방법: w_chronos = clamp(0.1, hr_c / (hr_c + hr_l), 0.9), w_lgbm = 1 - w_chronos. 3. shadow 행 (model='chronos' / 'lgbm') 의 hit_rate 를 비교해서
hit_rate 데이터가 부족하면 default 0.6/0.4 유지.) ensemble_weights 자동 보정.
지금은 'ensemble' 모델 단일 종류로 predictions 가 쌓이므로, 가중치 보정은 가중치 공식:
chronos 단독 시뮬레이션 / lgbm 단독 시뮬레이션 hit_rate 비교가 진정한 방식인데, w_c = clamp(0.1, hr_c / (hr_c + hr_l), 0.9)
Phase 4 단순화: 'ensemble' 의 종합 hit_rate 만 model_performance 에 기록하고 w_l = 1 - w_c
가중치는 default 유지. 진짜 비교는 Phase 7 (chronos 단독 + lgbm 단독 예측을 단 sample_count_c < MIN_SAMPLE 또는 sample_count_l < MIN_SAMPLE 이면
shadow 로 같이 적재하는 구조) 로 확장. 기본값 유지 (DB row 미생성). hr_c + hr_l == 0 (둘 다 0%) 이면 50:50.
predict_one 이 매 호출마다 chronos/lgbm shadow 행을 함께 적재하고
match_outcomes 가 user_triggered 무관하게 매칭하므로, hit_rate 데이터는
사용자가 예측을 한 번이라도 본 종목에 대해 자연스럽게 쌓인다.
""" """
from __future__ import annotations from __future__ import annotations
@@ -24,12 +28,16 @@ from sqlalchemy import text
from app.db.connection import get_engine from app.db.connection import get_engine
from app.models.lgbm import train_one from app.models.lgbm import train_one
from app.models.weights import upsert_weights
from app.seed.seed_tickers import SEED_TICKERS from app.seed.seed_tickers import SEED_TICKERS
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
HORIZONS = (1, 3, 5) HORIZONS = (1, 3, 5)
WINDOW_DAYS = 30 WINDOW_DAYS = 30
MIN_SAMPLE = 10 # 모델당 최소 매칭 표본
W_CHRONOS_MIN = 0.1
W_CHRONOS_MAX = 0.9
def retrain_all() -> list[dict[str, Any]]: def retrain_all() -> list[dict[str, Any]]:
@@ -103,6 +111,82 @@ def record_performance(as_of: date) -> list[dict[str, Any]]:
return summary return summary
def adjust_weights(as_of: date) -> list[dict[str, Any]]:
"""shadow chronos/lgbm hit_rate 로 ensemble_weights 자동 보정.
반환: (code, horizon, w_chronos, w_lgbm, hr_c, hr_l, n_c, n_l, action) 의
리스트. action ∈ {'updated', 'skipped_insufficient', 'skipped_zero'}.
"""
eng = get_engine()
start = as_of - timedelta(days=WINDOW_DAYS)
out: list[dict[str, Any]] = []
with eng.begin() as conn:
rows = conn.execute(
text(
"""
SELECT code, horizon, model,
AVG(CASE WHEN direction_hit THEN 1.0 ELSE 0.0 END) AS hit_rate,
COUNT(*) AS n
FROM prediction_outcomes
WHERE resolved_at >= :start
AND model IN ('chronos', 'lgbm')
GROUP BY code, horizon, model
"""
),
{"start": start},
).all()
# (code, horizon) -> {'chronos': (hr, n), 'lgbm': (hr, n)}
agg: dict[tuple[str, int], dict[str, tuple[float, int]]] = {}
for code, horizon, model, hr, n in rows:
key = (code, int(horizon))
agg.setdefault(key, {})[str(model)] = (
float(hr) if hr is not None else 0.0,
int(n),
)
for (code, horizon), m in agg.items():
c = m.get("chronos")
l = m.get("lgbm")
if c is None or l is None:
out.append({
"code": code, "horizon": horizon,
"action": "skipped_missing_model",
"have": list(m.keys()),
})
continue
hr_c, n_c = c
hr_l, n_l = l
if n_c < MIN_SAMPLE or n_l < MIN_SAMPLE:
out.append({
"code": code, "horizon": horizon,
"hr_chronos": hr_c, "hr_lgbm": hr_l,
"n_chronos": n_c, "n_lgbm": n_l,
"action": "skipped_insufficient",
})
continue
total = hr_c + hr_l
if total <= 0:
w_c = 0.5
else:
w_c = hr_c / total
w_c = max(W_CHRONOS_MIN, min(W_CHRONOS_MAX, w_c))
w_l = 1.0 - w_c
upsert_weights(
code, horizon, w_c, w_l,
hit_rate_chronos=hr_c, hit_rate_lgbm=hr_l,
sample_count=min(n_c, n_l),
)
out.append({
"code": code, "horizon": horizon,
"hr_chronos": hr_c, "hr_lgbm": hr_l,
"n_chronos": n_c, "n_lgbm": n_l,
"w_chronos": w_c, "w_lgbm": w_l,
"action": "updated",
})
return out
def run_weekly() -> dict[str, Any]: def run_weekly() -> dict[str, Any]:
"""일요일 02:00 KST 호출 entry-point.""" """일요일 02:00 KST 호출 entry-point."""
from datetime import datetime, timezone from datetime import datetime, timezone
@@ -113,6 +197,7 @@ def run_weekly() -> dict[str, Any]:
"as_of": str(as_of), "as_of": str(as_of),
"trained": retrain_all(), "trained": retrain_all(),
"performance": record_performance(as_of), "performance": record_performance(as_of),
"weights": adjust_weights(as_of),
} }

View File

@@ -72,6 +72,7 @@ export type PredictResponse = {
sources_used: string[]; sources_used: string[];
steps: PredictionStep[]; steps: PredictionStep[];
saved_prediction_ids: number[]; saved_prediction_ids: number[];
saved_shadow_ids?: { chronos: number[]; lgbm: number[] };
user_triggered: boolean; user_triggered: boolean;
}; };