"""MC Domain Filter Proxy. asyncio 기반 TCP 프록시. 동작 순서: 1) 클라이언트가 연결되면 첫 핸드셰이크 패킷을 읽는다. 2) 패킷에서 server_address 를 꺼내 허용 도메인 목록과 대조한다. 3) 허용되면 백엔드 MC 서버에 연결하고, 받은 핸드셰이크 바이트를 그대로 forward 한 뒤 양방향으로 TCP 를 중계한다. 4) 허용되지 않으면 즉시 연결을 종료한다 (응답을 보내지 않음). """ from __future__ import annotations import asyncio import logging import os import sqlite3 import time from pathlib import Path import config as cfg_mod from handshake import HandshakeError, parse_handshake, read_handshake_bytes LOG_DB = Path(os.environ.get("MC_LOG_DB", "/data/logs.db")) RESTART_SIGNAL = Path(os.environ.get("MC_RESTART_SIGNAL", "/data/restart.signal")) logging.basicConfig( level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s", ) log = logging.getLogger("proxy") # --------------------------------------------------------------------------- # Log DB # --------------------------------------------------------------------------- def init_db() -> None: LOG_DB.parent.mkdir(parents=True, exist_ok=True) con = sqlite3.connect(LOG_DB) con.execute("PRAGMA journal_mode=WAL;") con.execute( """ CREATE TABLE IF NOT EXISTS connections ( id INTEGER PRIMARY KEY AUTOINCREMENT, ts REAL NOT NULL, client_ip TEXT NOT NULL, domain TEXT, next_state INTEGER, action TEXT NOT NULL, reason TEXT ) """ ) con.execute("CREATE INDEX IF NOT EXISTS idx_connections_ts ON connections(ts);") con.commit() con.close() def log_event( client_ip: str, domain: str | None, next_state: int | None, action: str, reason: str = "", ) -> None: try: con = sqlite3.connect(LOG_DB, timeout=2) con.execute( "INSERT INTO connections(ts, client_ip, domain, next_state, action, reason) " "VALUES(?,?,?,?,?,?)", (time.time(), client_ip, domain, next_state, action, reason), ) con.commit() con.close() except Exception as exc: # noqa: BLE001 - 로그 실패는 본 흐름을 막지 않는다 log.warning("log write failed: %s", exc) # --------------------------------------------------------------------------- # Runtime state # 
# ---------------------------------------------------------------------------
# Runtime state
# ---------------------------------------------------------------------------


def _signal_mtime() -> float:
    """Return the mtime of RESTART_SIGNAL, or 0.0 when the file is missing."""
    try:
        return RESTART_SIGNAL.stat().st_mtime
    except FileNotFoundError:
        return 0.0


class ProxyState:
    """Current config snapshot plus the change-detection state for it.

    ``cfg`` is the last successfully loaded config. ``cfg_mtime`` and
    ``signal_mtime`` are the mtimes observed when the config file and the
    restart-signal file were last consumed.
    """

    def __init__(self) -> None:
        self.cfg = cfg_mod.load()
        self.cfg_mtime = cfg_mod.mtime()
        self.signal_mtime = _signal_mtime()
        # Port the listener binds to; config_watcher updates it right before
        # restarting the listener.
        self.listen_port: int = int(self.cfg["proxy"]["listen_port"])

    @staticmethod
    def _enabled_of(cfg: dict) -> bool:
        """Return ``proxy.enabled`` from *cfg*, defaulting to True."""
        return bool(cfg.get("proxy", {}).get("enabled", True))

    @staticmethod
    def _backend_of(cfg: dict) -> tuple[str, int]:
        """Return ``(host, port)`` from the ``backend`` section of *cfg*."""
        b = cfg["backend"]
        return b["host"], int(b["port"])

    def allowed(self) -> set[str]:
        """Return the set of allowed domains from the current config."""
        return cfg_mod.allowed_domain_set(self.cfg)

    def backend(self) -> tuple[str, int]:
        """Return the backend ``(host, port)`` from the current config."""
        return self._backend_of(self.cfg)

    def enabled(self) -> bool:
        """Return whether the proxy is enabled in the current config."""
        return self._enabled_of(self.cfg)

    def reload_if_changed(self) -> bool:
        """Reload the config file if its mtime changed.

        The new config is fully validated (enabled flag, backend, allowed
        domains, listen port) *before* it replaces ``self.cfg``. A broken
        file therefore leaves the previous working config in place instead
        of being half-applied. On failure ``cfg_mtime`` is not advanced, so
        the reload is retried on the next poll.

        Returns:
            True if a new config was loaded and applied, otherwise False.
        """
        m = cfg_mod.mtime()
        if m == self.cfg_mtime:
            return False
        try:
            new_cfg = cfg_mod.load()
            enabled = self._enabled_of(new_cfg)
            backend = self._backend_of(new_cfg)
            domains = sorted(cfg_mod.allowed_domain_set(new_cfg))
            # config_watcher reads this key right after a reload; reject a
            # config where it is missing or not an integer.
            int(new_cfg["proxy"]["listen_port"])
        except Exception as exc:  # noqa: BLE001
            log.warning("config reload failed: %s", exc)
            return False
        self.cfg = new_cfg
        self.cfg_mtime = m
        log.info(
            "config reloaded: enabled=%s backend=%s domains=%s",
            enabled,
            backend,
            domains,
        )
        return True

    def check_restart_signal(self) -> bool:
        """Return True once per change of the restart-signal file's mtime.

        The signal file is touched by ``POST /api/proxy/restart``.
        """
        m = _signal_mtime()
        if m == self.signal_mtime:
            return False
        self.signal_mtime = m
        log.info("restart signal received")
        return True


# ---------------------------------------------------------------------------
# TCP tunneling
# ---------------------------------------------------------------------------


async def _pipe(
    reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
    """Copy bytes from *reader* to *writer* until EOF or error, then close *writer*.

    Reset/broken-pipe errors are expected when a peer disconnects and are
    ignored; any other exception is logged at DEBUG level.
    """
    try:
        while True:
            data = await reader.read(8192)
            if not data:
                break
            writer.write(data)
            await writer.drain()
    except (ConnectionResetError, BrokenPipeError, asyncio.IncompleteReadError):
        pass
    except Exception as exc:  # noqa: BLE001
        log.debug("pipe error: %s", exc)
    finally:
        try:
            writer.close()
        except Exception:  # noqa: BLE001
            pass


async def handle_client(
    client_reader: asyncio.StreamReader,
    client_writer: asyncio.StreamWriter,
    state: ProxyState,
) -> None:
    """Filter one client connection by handshake domain, then tunnel it.

    Every decision (blocked / error / allowed) is recorded via log_event.
    Blocked or failed connections are closed without sending any response.
    """
    peer = client_writer.get_extra_info("peername") or ("?", 0)
    client_ip = peer[0]

    if not state.enabled():
        log_event(client_ip, None, None, "blocked", "proxy disabled")
        client_writer.close()
        return

    try:
        hs_bytes = await asyncio.wait_for(
            read_handshake_bytes(client_reader), timeout=5
        )
        hs = parse_handshake(hs_bytes)
    except (HandshakeError, asyncio.TimeoutError, asyncio.IncompleteReadError, OSError) as exc:
        log_event(client_ip, None, None, "blocked", f"handshake error: {exc}")
        log.info("BLOCK %s reason=handshake_error (%s)", client_ip, exc)
        client_writer.close()
        return

    # Exact match against the allowed set after lower()/strip().
    # NOTE(review): Forge clients are known to append markers such as
    # "\0FML\0" to server_address, which would fail this check — confirm
    # whether parse_handshake or the allowed set accounts for that.
    domain = hs.server_address.lower().strip()
    if domain not in state.allowed():
        log_event(client_ip, domain, hs.next_state, "blocked", "domain not allowed")
        log.info(
            "BLOCK %s domain=%r next_state=%d", client_ip, domain, hs.next_state
        )
        client_writer.close()
        return

    backend_host, backend_port = state.backend()
    try:
        backend_reader, backend_writer = await asyncio.wait_for(
            asyncio.open_connection(backend_host, backend_port), timeout=5
        )
    except (OSError, asyncio.TimeoutError) as exc:
        log_event(
            client_ip, domain, hs.next_state, "error", f"backend connect failed: {exc}"
        )
        log.warning(
            "ERROR %s domain=%r backend=%s:%d msg=%s",
            client_ip,
            domain,
            backend_host,
            backend_port,
            exc,
        )
        client_writer.close()
        return

    log_event(client_ip, domain, hs.next_state, "allowed")
    log.info(
        "PASS %s -> %s:%d domain=%r next_state=%d",
        client_ip,
        backend_host,
        backend_port,
        domain,
        hs.next_state,
    )

    # The handshake was consumed from the client stream, so replay the
    # captured bytes to the backend unchanged before relaying the rest.
    backend_writer.write(hs_bytes)
    try:
        await backend_writer.drain()
    except Exception as exc:  # noqa: BLE001
        log.debug("backend write failed for %s: %s", client_ip, exc)
        # Close BOTH sides; previously the backend connection was leaked.
        backend_writer.close()
        client_writer.close()
        return

    await asyncio.gather(
        _pipe(client_reader, backend_writer),
        _pipe(backend_reader, client_writer),
    )


# ---------------------------------------------------------------------------
# Listener lifecycle
# ---------------------------------------------------------------------------


class Listener:
    """Owns the asyncio TCP server and (re)binds it from ProxyState."""

    def __init__(self, state: ProxyState, stop_timeout: float = 5.0) -> None:
        self.state = state
        self.server: asyncio.base_events.Server | None = None
        # Upper bound (seconds) for Server.wait_closed() in stop(). Since
        # Python 3.12.1 wait_closed() also waits for every active client
        # connection, which would otherwise stall a restart for as long as
        # any player stays connected.
        self.stop_timeout = stop_timeout

    async def start(self) -> None:
        """Bind the server on 0.0.0.0:listen_port unless the proxy is disabled."""
        if not self.state.enabled():
            log.info("proxy disabled by config; not listening")
            return
        self.server = await asyncio.start_server(
            lambda r, w: handle_client(r, w, self.state),
            host="0.0.0.0",
            port=self.state.listen_port,
        )
        log.info("listening on 0.0.0.0:%d", self.state.listen_port)

    async def stop(self) -> None:
        """Stop accepting new connections; no-op if not listening.

        Already-established tunnels are not torn down here. Waiting for the
        server to finish closing is bounded by ``stop_timeout``.
        """
        server, self.server = self.server, None
        if server is None:
            return
        server.close()
        try:
            await asyncio.wait_for(server.wait_closed(), timeout=self.stop_timeout)
        except asyncio.TimeoutError:
            log.info("listener closed; active connections left running")
        log.info("listener stopped")

    async def restart(self) -> None:
        """Stop and start again, picking up the current listen_port/enabled."""
        await self.stop()
        await self.start()


async def config_watcher(state: ProxyState, listener: Listener) -> None:
    """Poll every 2 s for config changes and restart requests.

    The listener is restarted when the restart-signal file changes, or when
    a reloaded config changes ``listen_port`` or ``enabled``. An error in one
    iteration is logged instead of killing this task, which runs
    un-awaited. One example is the new port failing to bind. A failed
    restart stays pending and is retried on the next tick.
    """
    restart_pending = False
    while True:
        await asyncio.sleep(2)
        try:
            old_port = state.listen_port
            old_enabled = state.enabled()
            config_changed = state.reload_if_changed()
            signal_changed = state.check_restart_signal()
            new_port = int(state.cfg["proxy"]["listen_port"])
            new_enabled = state.enabled()

            port_or_enabled_changed = config_changed and (
                new_port != old_port or new_enabled != old_enabled
            )
            if signal_changed or port_or_enabled_changed:
                restart_pending = True

            if restart_pending:
                state.listen_port = new_port
                await listener.restart()
                restart_pending = False
        except asyncio.CancelledError:
            raise
        except Exception:  # noqa: BLE001 - keep the watcher alive
            log.exception("config watcher iteration failed; retrying in 2s")


async def main() -> None:
    """Initialise the log DB, start listening and run until cancelled."""
    init_db()
    state = ProxyState()
    listener = Listener(state)
    await listener.start()
    watcher = asyncio.create_task(config_watcher(state, listener))
    try:
        # This event is never set: block until the task is cancelled.
        await asyncio.Event().wait()
    finally:
        watcher.cancel()
        await listener.stop()


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        pass