- ensemble.predict() 가 chronos_raw / lgbm_raw 를 함께 반환
- predict_and_store() 가 매 호출마다 3종 행 적재:
  - model='ensemble' (user_triggered=인자)
  - model='chronos' (user_triggered=FALSE, shadow)
  - model='lgbm' (user_triggered=FALSE, shadow)
- retrain_weekly.adjust_weights(): 최근 30일 prediction_outcomes 의
  chronos vs lgbm hit_rate 로 ensemble_weights upsert
  - w_chronos = clamp(0.1, hr_c/(hr_c+hr_l), 0.9), w_lgbm = 1 - w_chronos
  - 모델별 표본 < 10 이면 기본값(0.6/0.4) 유지
- API 응답에 saved_shadow_ids 추가 (TS 타입도 동기화)
- README: 동작 모델 메모 섹션을 실제 구현과 일치하도록 갱신
리뷰어 지적 3번 (ensemble_weights 가 영원히 갱신 안됨, upsert_weights 미호출) 해결.
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
183 lines
6.3 KiB
Python
183 lines
6.3 KiB
Python
"""Chronos + LGBM ensemble inference.

For each horizon h the blended forecast is::

    final_price[h] = w_c * chronos.median[h-1] + w_l * lgbm.predicted_close
    final_q10[h]   = w_c * chronos.q10[h-1]    + w_l * lgbm.predicted_close * 0.97
    final_q90[h]   = w_c * chronos.q90[h-1]    + w_l * lgbm.predicted_close * 1.03

w_c / w_l come from ``load_weights(code, h)`` (default 0.6 / 0.4 when no DB row).

LGBM only yields a single point forecast (the predicted close for its horizon),
so it has no interval of its own. As an approximation, a +/-3% band around that
point stands in for LGBM's q10/q90. The Chronos sample quantiles are therefore
the main source of interval information.

Direction probabilities:
- LGBM: the classifier's 3-class prob_up / prob_flat / prob_down, used as-is.
- Chronos: the share of sample paths whose return at step h,
  ``samples[:, h-1] / base_close - 1``, is above +0.3% (up), below -0.3%
  (down), or in between (flat).
- When both are available they are averaged with equal (0.5 / 0.5) weights,
  independent of the price weights w_c / w_l.

If no LGBM model is available for a horizon (cold start), Chronos is used alone.
If Chronos fails, LGBM is used alone. If both fail for a horizon, a
RuntimeError is raised.
"""
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from dataclasses import dataclass, field
|
|
|
|
import numpy as np
|
|
|
|
from app.models.chronos import ChronosForecast
|
|
from app.models.chronos import forecast as chronos_forecast
|
|
from app.models.lgbm import LgbmForecast
|
|
from app.models.lgbm import predict_one as lgbm_predict
|
|
from app.models.weights import load_weights
|
|
|
|
# Module-level logger; per-model forecast failures are logged as warnings and
# inference continues with the remaining source.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class EnsembleStep:
    """Blended Chronos + LGBM forecast for a single horizon (built by ``predict``)."""

    horizon: int  # trading days ahead (1..H)
    target_idx: int  # 0-based index into the Chronos median for this horizon (horizon - 1)
    point_close: float  # weighted blend of Chronos median and LGBM predicted close
    ci_low: float  # weighted blend of Chronos q10 and LGBM close * 0.97
    ci_high: float  # weighted blend of Chronos q90 and LGBM close * 1.03
    prob_up: float  # direction probabilities; predict() renormalises them to sum to 1
    prob_flat: float
    prob_down: float
    direction: str  # 'up' / 'flat' / 'down'
    expected_return: float  # point_close / base_close - 1
|
|
|
|
|
|
@dataclass
class EnsemblePrediction:
    """Ensemble forecast for one ticker across all requested horizons."""

    code: str
    base_close: float  # latest close; the reference price for returns and direction
    horizons: list[int]
    steps: list[EnsembleStep]  # one step per horizon, in the same order as ``horizons``
    sources_used: list[str]  # e.g. "chronos", "lgbm_h1", ... for sources that produced output
    # Raw per-model outputs kept for shadow storage: predict_one.py stores three
    # kinds of rows in `predictions` (ensemble, chronos-only, lgbm-only) so that
    # retrain_weekly can compare per-model hit rates.
    chronos_raw: ChronosForecast | None = None
    lgbm_raw: dict[int, LgbmForecast] = field(default_factory=dict)  # keyed by horizon
|
|
|
|
|
|
def _chronos_direction(samples: list[list[float]], base_close: float, horizon: int) -> tuple[float, float, float]:
|
|
"""Chronos sample 분포에서 (prob_up, prob_flat, prob_down). ±0.3% flat band."""
|
|
if not samples:
|
|
return 0.33, 0.34, 0.33
|
|
arr = np.array(samples)[:, horizon - 1] # 해당 step 의 sample 값
|
|
ret = arr / base_close - 1.0
|
|
p_up = float((ret > 0.003).mean())
|
|
p_dn = float((ret < -0.003).mean())
|
|
p_fl = 1.0 - p_up - p_dn
|
|
return p_up, p_fl, p_dn
|
|
|
|
|
|
def predict(code: str, *, horizons: tuple[int, ...] = (1, 3, 5)) -> EnsemblePrediction:
    """Run the Chronos + LGBM ensemble for one ticker (on-demand inference).

    Chronos is run once, up to ``max(horizons)`` steps. LGBM is queried once
    per horizon. For each horizon:

    - Price and confidence interval are blended with the weights from
      ``load_weights``. If one source is unavailable, the other gets the full
      weight.
    - Direction probabilities of the available sources are averaged with equal
      weight.

    Args:
        code: Ticker code.
        horizons: Trading-day horizons to forecast. Must be non-empty, and every
            value must be >= 1, because h is used as a 1-based index into the
            Chronos forecast.

    Returns:
        An ``EnsemblePrediction`` with one ``EnsembleStep`` per horizon, in the
        order given, plus the raw Chronos / LGBM outputs for shadow storage.

    Raises:
        ValueError: ``horizons`` is empty or contains a value < 1.
        RuntimeError: there is no OHLCV data, the latest close is not a
            positive finite number, or both Chronos and LGBM failed for some
            horizon.
    """
    # Materialise once so one-shot iterables are not exhausted by max().
    horizons = tuple(horizons)
    if not horizons:
        raise ValueError("horizons must not be empty")
    invalid = [h for h in horizons if h < 1]
    if invalid:
        # h <= 0 would make h - 1 index the Chronos arrays from the end and
        # silently pick the wrong step.
        raise ValueError(f"horizons must be >= 1, got {invalid}")
    max_h = max(horizons)

    # Chronos: fetch the close series and forecast up to max_h steps ahead.
    from app.models.features import build_features  # local import

    ff = build_features(code, lookback_days=400, horizons=horizons, with_targets=False)
    df = ff.df
    if df.empty:
        raise RuntimeError(f"no OHLCV data for {code}")
    closes = df["close"].astype(float).tolist()
    base_close = float(closes[-1])
    if not np.isfinite(base_close) or base_close <= 0.0:
        # base_close is the denominator of every return below. A value of 0
        # would raise ZeroDivisionError, and NaN would silently poison every
        # output.
        raise RuntimeError(f"invalid latest close for {code}: {base_close}")

    sources_used: list[str] = []
    cf: ChronosForecast | None = None
    # Best effort: on failure, continue with LGBM only.
    try:
        cf = chronos_forecast(closes, horizon=max_h, num_samples=30)
    except Exception as exc:  # noqa: BLE001
        logger.warning("chronos forecast failed for %s: %s", code, exc)
    else:
        sources_used.append("chronos")

    steps: list[EnsembleStep] = []
    lgbm_raw: dict[int, LgbmForecast] = {}
    for h in horizons:
        lf: LgbmForecast | None = None
        # Best effort: on failure (or no model), continue with Chronos only.
        try:
            lf = lgbm_predict(code, h)
        except Exception as exc:  # noqa: BLE001
            logger.warning("lgbm predict failed for %s h=%d: %s", code, h, exc)
        if lf is not None:
            sources_used.append(f"lgbm_h{h}")
            lgbm_raw[h] = lf

        if cf is None and lf is None:
            raise RuntimeError(f"both chronos & lgbm failed for {code} h={h}")

        # Price weights (default 0.6 / 0.4 when there is no DB row).
        w = load_weights(code, h)
        wc, wl = w.w_chronos, w.w_lgbm
        # If one source is missing, the other one gets the full weight.
        if cf is None:
            wc, wl = 0.0, 1.0
        if lf is None:
            wc, wl = 1.0, 0.0

        if cf is not None:
            c_med = cf.median[h - 1]
            c_q10 = cf.q10[h - 1]
            c_q90 = cf.q90[h - 1]
        else:
            c_med = c_q10 = c_q90 = base_close  # unused: wc == 0

        if lf is not None:
            l_close = lf.predicted_close
            # LGBM has no interval of its own; approximate one with a +/-3% band.
            l_lo = l_close * 0.97
            l_hi = l_close * 1.03
            l_pu, l_pf, l_pd = lf.prob_up, lf.prob_flat, lf.prob_down
        else:
            l_close = l_lo = l_hi = base_close  # unused: wl == 0
            l_pu = l_pf = l_pd = 0.0

        point = wc * c_med + wl * l_close
        lo = wc * c_q10 + wl * l_lo
        hi = wc * c_q90 + wl * l_hi

        if cf is not None:
            cp_up, cp_fl, cp_dn = _chronos_direction(cf.samples, base_close, h)
        else:
            cp_up = cp_fl = cp_dn = 0.0

        # Direction probabilities: equal-weight average of the available
        # sources, independent of the price weights wc / wl.
        if lf is not None and cf is not None:
            p_up = 0.5 * cp_up + 0.5 * l_pu
            p_fl = 0.5 * cp_fl + 0.5 * l_pf
            p_dn = 0.5 * cp_dn + 0.5 * l_pd
        elif cf is not None:
            p_up, p_fl, p_dn = cp_up, cp_fl, cp_dn
        else:
            p_up, p_fl, p_dn = l_pu, l_pf, l_pd

        # Renormalise so the three probabilities sum to 1, in case a source's
        # own probabilities do not sum exactly to 1. The 1e-9 floor avoids
        # division by zero.
        s = max(p_up + p_fl + p_dn, 1e-9)
        p_up, p_fl, p_dn = p_up / s, p_fl / s, p_dn / s
        # Ties are broken in favour of 'up', then 'down'.
        dir_lbl = "up" if p_up >= max(p_fl, p_dn) else ("down" if p_dn >= p_fl else "flat")

        steps.append(
            EnsembleStep(
                horizon=h,
                target_idx=h - 1,
                point_close=float(point),
                ci_low=float(lo),
                ci_high=float(hi),
                prob_up=float(p_up),
                prob_flat=float(p_fl),
                prob_down=float(p_dn),
                direction=dir_lbl,
                expected_return=float(point / base_close - 1.0),
            )
        )

    return EnsemblePrediction(
        code=code,
        base_close=base_close,
        horizons=list(horizons),
        steps=steps,
        sources_used=sources_used,
        chronos_raw=cf,
        lgbm_raw=lgbm_raw,
    )
|