Two related issues diagnosed in production: 1) Blocked clients saw "Connection reset" instead of the custom block_message screen. Root cause: writer.close() returned before the kernel had flushed the Login Disconnect packet, and the OS sent RST instead of FIN. Fix: write_eof() + await wait_closed() so the FIN goes out after the payload and the client has time to read the chat component. 2) Log entries showed the reason "handshake error:" with an empty tail. Root cause: bare OSError() / ConnectionResetError() instances have an empty str(), so the f-string interpolated to nothing. Fix: prepend the exception class name so the reason is always informative.
385 lines
12 KiB
Python
385 lines
12 KiB
Python
"""MC Domain Filter Proxy.

asyncio-based TCP proxy. Order of operations:

1) When a client connects, read its first (handshake) packet.
2) Take ``server_address`` from that packet and check it against the list of
   allowed domains.
3) If allowed, connect to the backend MC server configured for that domain
   (or the default backend), forward the captured handshake bytes unchanged,
   and then relay TCP traffic in both directions.
4) If not allowed, end the connection. Login attempts (next_state=2) first
   receive a Login Disconnect packet carrying the configured block message;
   all other connections (e.g. status pings) are closed without a response.
"""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import sqlite3
|
|
import time
|
|
from pathlib import Path
|
|
|
|
import config as cfg_mod
|
|
from handshake import HandshakeError, parse_handshake, read_handshake_bytes
|
|
|
|
# Kick message sent to refused login attempts when the config has no
# "block_message" value. English: "This server can only be joined through
# an allowed domain."
DEFAULT_BLOCK_MESSAGE = "이 서버는 허용된 도메인에서만 접속 가능합니다."
|
|
|
|
|
|
def _encode_varint(n: int) -> bytes:
    """Encode *n* as a Minecraft-protocol VarInt.

    Seven bits per byte, least-significant group first. The high bit of each
    byte signals that another byte follows.

    Negative values are encoded as their 32-bit two's-complement form (always
    5 bytes), because the protocol's VarInt is a signed 32-bit integer.
    Without that normalisation the loop never terminates for a negative *n*:
    Python's ``>>`` is an arithmetic shift, so ``-1 >> 7`` stays ``-1``.

    Args:
        n: Integer to encode. Non-negative values are encoded as-is.

    Returns:
        The encoded bytes.
    """
    if n < 0:
        n &= 0xFFFFFFFF
    out = bytearray()
    while True:
        b = n & 0x7F
        n >>= 7
        if n:
            out.append(b | 0x80)  # more groups follow
        else:
            out.append(b)
            return bytes(out)
|
|
|
|
|
|
def build_login_disconnect(message: str) -> bytes:
    """Build a clientbound Login Disconnect packet (login state, packet id 0x00).

    Packet layout::

        VarInt  total length of the body below
        VarInt  packet id (0x00)
        VarInt  byte length of the UTF-8 JSON
        bytes   JSON chat component {"text": message, "color": "red"}

    On receiving this packet the client shows the chat component above
    instead of its generic "connection refused" screen.

    Args:
        message: Text shown to the player.

    Returns:
        The complete, length-framed packet ready to write to the socket.
    """
    component = {"text": message, "color": "red"}
    payload = json.dumps(component, ensure_ascii=False).encode("utf-8")
    # The string length prefix counts UTF-8 bytes, not characters.
    packet_id = _encode_varint(0x00)
    body = b"".join((packet_id, _encode_varint(len(payload)), payload))
    return _encode_varint(len(body)) + body
|
|
|
|
# SQLite file that receives one row per connection decision (override via
# the MC_LOG_DB environment variable).
LOG_DB = Path(os.getenv("MC_LOG_DB", "/data/logs.db"))
# File whose mtime change requests a listener restart (override via
# the MC_RESTART_SIGNAL environment variable).
RESTART_SIGNAL = Path(os.getenv("MC_RESTART_SIGNAL", "/data/restart.signal"))

logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
log = logging.getLogger("proxy")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Log DB
|
|
# ---------------------------------------------------------------------------
|
|
def init_db() -> None:
    """Create the connection-log database and its schema if missing.

    Creates ``LOG_DB``'s parent directory, switches the database to WAL
    journal mode, and ensures the ``connections`` table and its ``ts`` index
    exist. All DDL uses ``IF NOT EXISTS``, so calling this on every start-up
    is harmless. The connection is closed even if a statement fails.
    """
    LOG_DB.parent.mkdir(parents=True, exist_ok=True)
    con = sqlite3.connect(LOG_DB)
    try:
        # WAL lets readers query the log while the proxy keeps appending.
        con.execute("PRAGMA journal_mode=WAL;")
        con.execute(
            """
            CREATE TABLE IF NOT EXISTS connections (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                ts REAL NOT NULL,
                client_ip TEXT NOT NULL,
                domain TEXT,
                next_state INTEGER,
                action TEXT NOT NULL,
                reason TEXT
            )
            """
        )
        con.execute("CREATE INDEX IF NOT EXISTS idx_connections_ts ON connections(ts);")
        con.commit()
    finally:
        con.close()
|
|
|
|
|
|
def log_event(
    client_ip: str,
    domain: str | None,
    next_state: int | None,
    action: str,
    reason: str = "",
) -> None:
    """Append one row to the ``connections`` table in ``LOG_DB``.

    Best effort: any failure (locked database, disk error, ...) is logged as
    a warning and swallowed so that logging never interrupts the proxy flow.
    The SQLite connection is always closed, including when the insert or
    commit fails.

    NOTE(review): this runs synchronously on the event loop (it is called
    directly from ``handle_client``). With ``timeout=2`` a locked database can
    stall every connection for up to 2 s; consider ``asyncio.to_thread``.

    Args:
        client_ip: Peer address of the client.
        domain: Requested domain, or ``None`` when it is not known yet.
        next_state: Handshake ``next_state``, or ``None`` when unknown.
        action: Decision label (the handler uses "allowed", "blocked", "error").
        reason: Optional free-text explanation.
    """
    con: sqlite3.Connection | None = None
    try:
        con = sqlite3.connect(LOG_DB, timeout=2)
        con.execute(
            "INSERT INTO connections(ts, client_ip, domain, next_state, action, reason) "
            "VALUES(?,?,?,?,?,?)",
            (time.time(), client_ip, domain, next_state, action, reason),
        )
        con.commit()
    except Exception as exc:  # noqa: BLE001 - a failed log write must not block the main flow
        log.warning("log write failed: %s", exc)
    finally:
        if con is not None:
            con.close()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Runtime state
|
|
# ---------------------------------------------------------------------------
|
|
def _signal_mtime() -> float:
    """Return the restart-signal file's mtime, or ``0.0`` when the file is absent."""
    try:
        stat_result = RESTART_SIGNAL.stat()
    except FileNotFoundError:
        return 0.0
    return stat_result.st_mtime
|
|
|
|
|
|
class ProxyState:
|
|
def __init__(self) -> None:
|
|
self.cfg = cfg_mod.load()
|
|
self.cfg_mtime = cfg_mod.mtime()
|
|
self.signal_mtime = _signal_mtime()
|
|
self.listen_port: int = int(self.cfg["proxy"]["listen_port"])
|
|
|
|
def allowed(self) -> set[str]:
|
|
return cfg_mod.allowed_domain_set(self.cfg)
|
|
|
|
def backend(self) -> tuple[str, int]:
|
|
"""기본 백엔드 (도메인 entry 에 backend 가 없을 때 fallback)."""
|
|
b = self.cfg["backend"]
|
|
return b["host"], int(b["port"])
|
|
|
|
def backend_for(self, domain: str) -> tuple[str, int] | None:
|
|
"""주어진 도메인이 활성 화이트리스트에 있으면 라우팅 대상을 돌려준다.
|
|
|
|
도메인 entry 에 `backend.host`/`backend.port` 가 있으면 그 값을 우선,
|
|
없으면 top-level `backend` 로 fallback. 도메인이 비활성이거나 없으면
|
|
None.
|
|
"""
|
|
d = domain.lower().strip()
|
|
for entry in self.cfg.get("allowed_domains", []):
|
|
if entry["domain"].lower().strip() != d:
|
|
continue
|
|
if not entry.get("enabled", True):
|
|
return None
|
|
be = entry.get("backend") or {}
|
|
host = (be.get("host") or "").strip()
|
|
port = be.get("port")
|
|
if host and port:
|
|
return host, int(port)
|
|
return self.backend()
|
|
return None
|
|
|
|
def enabled(self) -> bool:
|
|
return bool(self.cfg.get("proxy", {}).get("enabled", True))
|
|
|
|
def reload_if_changed(self) -> bool:
|
|
m = cfg_mod.mtime()
|
|
if m == self.cfg_mtime:
|
|
return False
|
|
try:
|
|
self.cfg = cfg_mod.load()
|
|
self.cfg_mtime = m
|
|
log.info(
|
|
"config reloaded: enabled=%s backend=%s domains=%s",
|
|
self.enabled(),
|
|
self.backend(),
|
|
sorted(self.allowed()),
|
|
)
|
|
return True
|
|
except Exception as exc: # noqa: BLE001
|
|
log.warning("config reload failed: %s", exc)
|
|
return False
|
|
|
|
def check_restart_signal(self) -> bool:
|
|
"""`POST /api/proxy/restart` 가 touch 한 신호 파일 변경 여부."""
|
|
m = _signal_mtime()
|
|
if m == self.signal_mtime:
|
|
return False
|
|
self.signal_mtime = m
|
|
log.info("restart signal received")
|
|
return True
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# TCP tunneling
|
|
# ---------------------------------------------------------------------------
|
|
async def _pipe(
    reader: asyncio.StreamReader, writer: asyncio.StreamWriter
) -> None:
    """Copy bytes from *reader* to *writer* until EOF or a connection error.

    Connection resets, broken pipes and incomplete reads end the copy
    silently; any other exception is logged at DEBUG level. *writer* is
    always closed afterwards, which closes that whole connection (no TCP
    half-close is attempted).
    """
    try:
        chunk = await reader.read(8192)
        while chunk:
            writer.write(chunk)
            await writer.drain()  # respect the destination's flow control
            chunk = await reader.read(8192)
    except (ConnectionResetError, BrokenPipeError, asyncio.IncompleteReadError):
        pass
    except Exception as exc:  # noqa: BLE001
        log.debug("pipe error: %s", exc)
    finally:
        try:
            writer.close()
        except Exception:  # noqa: BLE001
            pass
|
|
|
|
|
|
# Upper bound (seconds) on how long a refused login connection is kept open
# after the Login Disconnect packet was sent, waiting for the client to close
# its side first.
_DISCONNECT_LINGER_SECONDS = 2.0


def _describe_exc(exc: BaseException) -> str:
    """Return ``"ClassName: message"``, or just ``"ClassName"`` if the message is empty.

    Several exceptions stringify to ``""`` (bare ``OSError()``,
    ``ConnectionResetError()``, the ``TimeoutError`` raised by
    ``asyncio.wait_for``), which would otherwise leave an empty tail in a
    logged reason.
    """
    detail = str(exc)
    name = type(exc).__name__
    return f"{name}: {detail}" if detail else name


def _normalize_domain(server_address: str) -> str:
    """Reduce a handshake ``server_address`` to a bare, lower-cased host name.

    Forge clients append a NUL-separated marker (NUL + "FML" / "FML2" /
    "FML3" + NUL) to the address. Everything from the first NUL onward is
    dropped so modded clients can still match an allowed domain. Only the
    matching uses this value; the captured handshake bytes are forwarded to
    the backend untouched.
    """
    return server_address.split("\0", 1)[0].strip().lower()


async def _close_writer(writer: asyncio.StreamWriter) -> None:
    """Close *writer* and wait for the transport to finish, ignoring errors."""
    try:
        writer.close()
        await writer.wait_closed()
    except Exception:  # noqa: BLE001 - the peer may already be gone
        pass


async def _discard_until_eof(reader: asyncio.StreamReader) -> None:
    """Read and drop incoming data until the peer closes its side."""
    while await reader.read(4096):
        pass


async def _disconnect_after_eof(
    reader: asyncio.StreamReader,
    writer: asyncio.StreamWriter,
    linger: float = _DISCONNECT_LINGER_SECONDS,
) -> None:
    """Finish a refused login so the client actually sees the disconnect packet.

    ``write_eof()`` sends FIN once the buffered packet is flushed, but closing
    the socket right after it is not enough. If client bytes are still unread
    when the socket is closed, or arrive after it, the kernel answers with RST
    instead of a clean FIN. An RST can make the client drop the Login
    Disconnect packet it has not read yet, which shows up as "Connection
    reset". So incoming data is read and discarded until the client closes
    its side (EOF), for at most *linger* seconds, and only then is the socket
    closed.
    """
    try:
        if writer.can_write_eof():
            writer.write_eof()
    except (OSError, NotImplementedError):
        pass
    try:
        await asyncio.wait_for(_discard_until_eof(reader), timeout=linger)
    except Exception:  # noqa: BLE001 - timeout or reset: just proceed to close
        pass
    await _close_writer(writer)


async def handle_client(
    client_reader: asyncio.StreamReader,
    client_writer: asyncio.StreamWriter,
    state: ProxyState,
) -> None:
    """Filter one client connection by its handshake domain and relay it if allowed.

    Flow:
      1. If the proxy is disabled, log and close immediately.
      2. Read the handshake (5 s timeout) and parse it. On failure, log the
         reason (always including the exception class) and close.
      3. Resolve the backend for the handshake's domain. Unknown or disabled
         domains are logged as blocked. Login attempts (next_state=2) first
         receive a Login Disconnect packet with the configured block message.
         Other intents (e.g. status pings) are closed without a reply, so the
         proxy does not reveal itself.
      4. Connect to the backend (5 s timeout), replay the captured handshake
         bytes, and relay traffic in both directions until a side closes.

    Every decision is recorded via ``log_event``.
    """
    peer = client_writer.get_extra_info("peername") or ("?", 0)
    client_ip = peer[0]

    if not state.enabled():
        log_event(client_ip, None, None, "blocked", "proxy disabled")
        client_writer.close()
        return

    try:
        hs_bytes = await asyncio.wait_for(
            read_handshake_bytes(client_reader), timeout=5
        )
        hs = parse_handshake(hs_bytes)
    except (
        HandshakeError,
        asyncio.TimeoutError,
        asyncio.IncompleteReadError,
        OSError,
    ) as exc:
        reason = f"handshake error: {_describe_exc(exc)}"
        log_event(client_ip, None, None, "blocked", reason)
        log.info("BLOCK %s reason=%s", client_ip, reason)
        await _close_writer(client_writer)
        return

    domain = _normalize_domain(hs.server_address)
    target = state.backend_for(domain)
    if target is None:
        log_event(client_ip, domain, hs.next_state, "blocked", "domain not allowed")
        log.info(
            "BLOCK %s domain=%r next_state=%d", client_ip, domain, hs.next_state
        )
        if hs.next_state == 2:
            # Login attempt: show the block message instead of silently dropping.
            msg = state.cfg.get("block_message") or DEFAULT_BLOCK_MESSAGE
            try:
                client_writer.write(build_login_disconnect(msg))
                await client_writer.drain()
            except OSError:
                await _close_writer(client_writer)
                return
            await _disconnect_after_eof(client_reader, client_writer)
        else:
            await _close_writer(client_writer)
        return

    backend_host, backend_port = target
    try:
        backend_reader, backend_writer = await asyncio.wait_for(
            asyncio.open_connection(backend_host, backend_port), timeout=5
        )
    except (OSError, asyncio.TimeoutError) as exc:
        detail = _describe_exc(exc)
        log_event(
            client_ip, domain, hs.next_state, "error", f"backend connect failed: {detail}"
        )
        log.warning(
            "ERROR %s domain=%r backend=%s:%d msg=%s",
            client_ip,
            domain,
            backend_host,
            backend_port,
            detail,
        )
        client_writer.close()
        return

    log_event(client_ip, domain, hs.next_state, "allowed")
    log.info(
        "PASS %s -> %s:%d domain=%r next_state=%d",
        client_ip,
        backend_host,
        backend_port,
        domain,
        hs.next_state,
    )

    # Replay the captured handshake bytes to the backend exactly as received.
    backend_writer.write(hs_bytes)
    try:
        await backend_writer.drain()
    except Exception:  # noqa: BLE001
        # Close both sides so the backend socket is not leaked.
        backend_writer.close()
        client_writer.close()
        return

    await asyncio.gather(
        _pipe(client_reader, backend_writer),
        _pipe(backend_reader, client_writer),
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Listener lifecycle
|
|
# ---------------------------------------------------------------------------
|
|
class Listener:
    """Owns the asyncio TCP server that accepts proxy clients.

    ``start()`` binds ``0.0.0.0:<state.listen_port>`` unless the proxy is
    disabled. ``stop()`` closes the listening socket, and ``restart()`` does
    both, e.g. after a port change or a restart signal.
    """

    # Upper bound (seconds) on how long stop() waits in Server.wait_closed().
    # Since Python 3.12, wait_closed() also waits for every active client
    # connection to finish, i.e. for every relayed player session to end.
    # Waiting unbounded would stall restart(): the old listening socket is
    # already closed and no new one is bound in the meantime. server.close()
    # only closes the listening socket, so established sessions keep running
    # either way.
    wait_closed_timeout: float = 1.0

    def __init__(self, state: ProxyState) -> None:
        self.state = state
        self.server: asyncio.base_events.Server | None = None

    async def start(self) -> None:
        """Start listening on ``state.listen_port`` unless the proxy is disabled."""
        if not self.state.enabled():
            log.info("proxy disabled by config; not listening")
            return
        self.server = await asyncio.start_server(
            # self.state is looked up per connection, when the lambda runs.
            lambda r, w: handle_client(r, w, self.state),
            host="0.0.0.0",
            port=self.state.listen_port,
        )
        log.info("listening on 0.0.0.0:%d", self.state.listen_port)

    async def stop(self) -> None:
        """Close the listening socket; established sessions are left running."""
        if self.server is None:
            return
        self.server.close()
        try:
            await asyncio.wait_for(
                self.server.wait_closed(), timeout=self.wait_closed_timeout
            )
        except asyncio.TimeoutError:
            # Client connections are still open; that is expected here.
            pass
        self.server = None
        log.info("listener stopped")

    async def restart(self) -> None:
        """Stop, then start again with the current enabled flag and listen port."""
        await self.stop()
        await self.start()
|
|
|
|
|
|
async def config_watcher(state: ProxyState, listener: Listener) -> None:
    """Poll every 2 s for config edits and restart signals; restart the listener when needed.

    The listener is restarted when the restart signal file changed, or when
    a reloaded config changed ``proxy.listen_port`` or ``proxy.enabled``.

    Each iteration is guarded. This coroutine runs as a fire-and-forget task,
    so an uncaught exception would silently end the task and stop all future
    reloads. Examples are a reloaded config missing ``proxy.listen_port``, or
    the new port already being in use. Failures are logged instead, and a
    restart that failed is retried on the next tick. Cancellation is not
    caught, so the task can still be stopped.
    """
    restart_pending = False
    while True:
        await asyncio.sleep(2)
        try:
            old_port = state.listen_port
            old_enabled = state.enabled()
            config_changed = state.reload_if_changed()
            # Record the signal immediately: check_restart_signal() consumes it,
            # so a later failure in this tick must not lose the request.
            if state.check_restart_signal():
                restart_pending = True

            new_port = int(state.cfg["proxy"]["listen_port"])
            new_enabled = state.enabled()
            if config_changed and (
                new_port != old_port or new_enabled != old_enabled
            ):
                restart_pending = True

            if restart_pending:
                state.listen_port = new_port
                await listener.restart()
                restart_pending = False
        except Exception as exc:  # noqa: BLE001 - keep watching after a bad tick
            log.warning("config watcher error: %r", exc)
|
|
|
|
|
|
async def main() -> None:
    """Entry coroutine: prepare the log DB, start listening, and watch the config.

    Runs until the task is cancelled. On the way out the watcher task is
    cancelled and the listener is stopped.
    """
    init_db()
    state = ProxyState()
    listener = Listener(state)
    await listener.start()
    watcher_task = asyncio.create_task(config_watcher(state, listener))
    never_set = asyncio.Event()
    try:
        # Nothing ever sets this event: park here for the life of the process.
        await never_set.wait()
    finally:
        watcher_task.cancel()
        await listener.stop()
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the proxy until interrupted. Ctrl+C
    # (KeyboardInterrupt) ends the process quietly, without a traceback.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        pass
|