EdgeWatch/backend.py
Richard Nixon 0e816fb966 feat(security): validate full MQTT payload schema with Pydantic
Previously only device_id and cpu_percent had explicit checks. The
rest of the payload — mem_percent, disk_percent, gpu_percent,
timestamp, agent_cpu_percent, agent_mem_mb — was trusted as long
as json.loads accepted it, so a malicious or buggy publisher could
push:

  - mem_percent: "<script>alert(1)</script>" (rendered later in
    the WS dashboard / Slack summary as if numeric),
  - disk_percent: NaN (which compares False everywhere and breaks
    downstream chart aggregation),
  - extra keys ("evil": "<!channel>"), persisted in device_state
    and forwarded verbatim to WS clients.

Pydantic Metrics model now enforces the whole frame:
  - device_id pattern (same regex as validate_device_id),
  - percentages bounded to [0, 100],
  - explicit NaN rejection (a finite-value @field_validator on top
    of Field(ge/le), which already excludes inf),
  - timestamp >= 0,
  - extra="forbid" so frames carrying unknown keys are rejected at the door.

on_message now goes through parse_metrics() which logs a WARNING
with the structured pydantic error list on rejection.
2026-05-17 17:19:50 +01:00

407 lines
14 KiB
Python

"""
Backend bridge:
- Subscribes to MQTT broker (Mosquitto / AWS IoT Core).
- Broadcasts device metrics to WebSocket clients.
- Sends alerts and summaries to Slack via webhook.
"""
import asyncio
import json
import logging
import math
import os
import re
import secrets
import ssl
import time
from typing import Optional
from urllib.parse import parse_qs, urlsplit

import aiohttp
from gmqtt import Client as MQTTClient
from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
from websockets import serve  # websockets >= 10
class BackendSettings(BaseSettings):
    """Typed configuration. Values are read from environment variables or
    a ``.env`` file at startup and validated by pydantic. A bad value
    (``WS_PORT=abc``, ``CPU_ALERT_TH=200``) fails fast with a readable
    message instead of crashing deep inside a callback. Cross-field
    constraints (``MQTT_BACKOFF_MAX >= MQTT_BACKOFF_MIN``) are checked too.
    """

    model_config = SettingsConfigDict(
        env_file=".env", env_file_encoding="utf-8", extra="ignore"
    )

    # MQTT
    mqtt_broker: str = ""
    mqtt_port: int = Field(default=8883, ge=1, le=65535)
    topic: str = "devices/+/metrics"
    ca_cert: Optional[str] = None
    client_cert: Optional[str] = None
    client_key: Optional[str] = None
    # Reconnect delay bounds, in seconds.
    mqtt_backoff_min: int = Field(default=1, ge=1)
    mqtt_backoff_max: int = Field(default=60, ge=1)

    # WebSocket
    ws_host: str = "0.0.0.0"
    ws_port: int = Field(default=6789, ge=1, le=65535)
    # Shared token required as ?token=... on the WS handshake. Empty = dev.
    ws_auth_token: str = ""

    # App
    prune_seconds: int = Field(default=30, ge=0)
    cpu_alert_th: float = Field(default=90.0, ge=0, le=100)
    slack_webhook_url: Optional[str] = None
    summary_interval: int = Field(default=60, ge=1)

    # Hardening against untrusted MQTT input.
    max_payload_bytes: int = Field(default=16 * 1024, ge=64)
    max_devices: int = Field(default=1000, ge=1)
    alert_cooldown: int = Field(default=60, ge=0)
    log_level: str = "INFO"

    @field_validator("mqtt_backoff_max")
    @classmethod
    def _backoff_max_not_below_min(cls, value, info):
        """Reject a backoff ceiling lower than the backoff floor.

        ``mqtt_backoff_min`` is declared before this field, so its validated
        value is in ``info.data``; it is absent only if it failed validation
        itself, in which case that error is already being reported.
        """
        minimum = info.data.get("mqtt_backoff_min")
        if minimum is not None and value < minimum:
            raise ValueError(
                f"mqtt_backoff_max ({value}) must be >= mqtt_backoff_min ({minimum})"
            )
        return value
settings = BackendSettings()

# Root logging is configured once, at import time, from the validated settings.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    level=settings.log_level.upper(),
)
log = logging.getLogger("edgewatch.backend")
_DEVICE_ID_RE = re.compile(r"^[A-Za-z0-9._-]{1,64}$")
_DEVICE_ID_PATTERN = r"^[A-Za-z0-9._-]{1,64}$"
_last_alert: dict[str, float] = {}
class Metrics(BaseModel):
    """Schema for an incoming MQTT metrics frame.

    Pydantic rejects everything that does not match: oversize/odd device
    ids, NaN/inf, percentages outside [0, 100], missing required fields,
    and (via ``extra="forbid"``) any unexpected key a malicious publisher
    might try to smuggle in.
    """

    model_config = ConfigDict(extra="forbid", str_strip_whitespace=False)

    device_id: str = Field(pattern=_DEVICE_ID_PATTERN)
    timestamp: int = Field(ge=0)
    cpu_percent: float = Field(ge=0, le=100)
    mem_percent: float = Field(ge=0, le=100)
    disk_percent: float = Field(ge=0, le=100)
    gpu_percent: Optional[float] = Field(default=None, ge=0, le=100)
    # Upper bound is 1000 rather than 100 (presumably multi-core usage of the
    # agent process) — confirm against the agent.
    agent_cpu_percent: Optional[float] = Field(default=None, ge=0, le=1000)
    agent_mem_mb: Optional[float] = Field(default=None, ge=0)

    @field_validator(
        "cpu_percent",
        "mem_percent",
        "disk_percent",
        "gpu_percent",
        "agent_cpu_percent",
        "agent_mem_mb",
    )
    @classmethod
    def _reject_nonfinite(cls, value):
        """Reject NaN and +/-inf in any numeric metric.

        Bounds alone are not enough. NaN compares False against every bound.
        A field with only a lower bound (``agent_mem_mb``) accepts ``+inf``.
        Either value would later serialize as the non-standard JSON tokens
        ``NaN``/``Infinity`` when broadcast to WebSocket clients.
        """
        if value is not None and not math.isfinite(value):
            raise ValueError("must be a finite number")
        return value
def _reject_json_constant(token: str):
    """``json.loads`` hook: refuse the non-standard ``NaN``/``Infinity`` tokens."""
    raise ValueError(f"non-standard JSON constant {token!r} not allowed")


def parse_metrics(raw: str) -> Optional[Metrics]:
    """Validate a raw JSON frame; return a Metrics instance or None.

    Failures are logged at WARNING and yield ``None`` instead of raising. This
    covers decode failures and schema mismatches. Decode failures include
    malformed JSON, ``NaN``/``Infinity`` constants, pathologically deep
    nesting, and integer literals over the interpreter's digit limit. The
    payload size cap is not checked here; the caller enforces it first.
    """
    try:
        data = json.loads(raw, parse_constant=_reject_json_constant)
    except (ValueError, RecursionError, TypeError) as err:
        # ValueError also covers json.JSONDecodeError, UnicodeDecodeError
        # (bytes input) and the int-digit-limit error for huge integers.
        # RecursionError is raised for deeply nested arrays/objects.
        log.warning("MQTT JSON decode failed: %s", err)
        return None
    try:
        return Metrics.model_validate(data)
    except ValidationError as err:
        log.warning("MQTT payload rejected by schema: %s", err.errors())
        return None
# Shared in-process state, written by the MQTT callbacks and read through
# active_snapshot() by the WebSocket handler and the Slack summary:
#   clients       - connected WebSocket connections
#   device_state  - device_id -> last validated metrics dict
#   last_seen     - device_id -> time.time() of the last accepted frame
clients, device_state, last_seen = set(), {}, {}
def filter_active(state, seen, now, prune_seconds):
    """Return the subset of ``state`` whose devices are still considered live.

    A device is live when ``now - seen[device_id] <= prune_seconds``. A device
    missing from ``seen`` is treated as last seen at time 0. Insertion order
    of ``state`` is preserved. No globals are touched, so the pruning rule can
    be unit-tested directly.
    """
    live = {}
    for device_id, payload in state.items():
        age = now - seen.get(device_id, 0)
        if age <= prune_seconds:
            live[device_id] = payload
    return live
def active_snapshot():
    """Devices from ``device_state`` updated within ``settings.prune_seconds`` of now."""
    now = time.time()
    return filter_active(device_state, last_seen, now, settings.prune_seconds)
def validate_device_id(raw):
    """Return ``raw`` if it is a safe device id, else ``None``.

    Restricted to ``[A-Za-z0-9._-]{1,64}`` so it cannot smuggle Slack mrkdwn
    metacharacters, newlines or unbounded lengths through the alert path.
    Uses ``fullmatch`` because, with Python's ``re``, ``$`` also matches just
    before a trailing newline. ``match`` would therefore accept ``"dev\\n"``.
    """
    if isinstance(raw, str) and _DEVICE_ID_RE.fullmatch(raw):
        return raw
    return None
def should_alert(cpu, device_id, now, last_alert_map, cooldown, threshold):
    """Return True when this CPU sample should trigger a Slack alert.

    Pure: the answer depends only on the arguments. A sample qualifies only
    when every check passes:
      * ``cpu`` is a real ``int``/``float`` (``bool`` is rejected even though
        it subclasses ``int``);
      * it lies in [0, 100] (this also rules out NaN and +/-inf);
      * it is at or above ``threshold``;
      * more than ``cooldown`` seconds have passed since the device's last
        alert in ``last_alert_map`` (missing entries count as time 0).
    """
    if isinstance(cpu, bool) or not isinstance(cpu, (int, float)):
        return False
    if not 0 <= cpu <= 100:
        return False
    if not cpu >= threshold:
        return False
    since_last = now - last_alert_map.get(device_id, 0)
    return since_last > cooldown
def next_backoff(current, ceiling):
    """Return the next reconnect delay: twice ``current``, capped at ``ceiling``.

    Pure so the backoff schedule can be unit-tested.
    """
    doubled = current * 2
    return ceiling if doubled > ceiling else doubled
def _ws_token_ok(websocket) -> bool:
    """Return True if the handshake carries a valid shared token.

    When ``settings.ws_auth_token`` is empty the server is in open mode and
    accepts any client (useful for local dev). When it is set, the client
    must connect to ``ws://host:port/?token=<value>``. The comparison is
    constant-time.

    This runs before ``ws_handler``'s try/except, so it must not raise on
    hostile input. A malformed request target returns False. Tokens are
    compared as UTF-8 bytes, so a non-ASCII token no longer makes
    ``secrets.compare_digest`` raise ``TypeError``.
    """
    expected = settings.ws_auth_token
    if not expected:
        return True
    # websockets >= 11 exposes the full request-target (path + query) via
    # connection.request.path; older versions kept it on connection.path.
    request = getattr(websocket, "request", None)
    target = getattr(request, "path", None) or getattr(websocket, "path", "") or ""
    try:
        token = (parse_qs(urlsplit(target).query).get("token") or [""])[0]
    except ValueError:
        # e.g. a target such as "//[" is parsed as an invalid IPv6 netloc.
        return False
    if not token:
        return False
    return secrets.compare_digest(token.encode("utf-8"), expected.encode("utf-8"))
async def ws_handler(websocket):
    """Serve a single WebSocket client for the lifetime of its connection.

    A handshake without a valid token is closed with code 1008 ("unauthorized").
    An accepted client is registered in ``clients`` and sent the current
    active snapshot (when non-empty). It then stays registered until the
    connection ends; anything it sends is read and discarded. Errors are
    logged, and the client is always unregistered on exit.
    """
    authorized = _ws_token_ok(websocket)
    if not authorized:
        log.warning("WS rejecting unauthorized connection")
        await websocket.close(code=1008, reason="unauthorized")
        return

    clients.add(websocket)
    log.info("WS connected total=%d", len(clients))
    try:
        initial = active_snapshot()
        if initial:
            log.info("WS sending initial snapshot devices=%d", len(initial))
            await websocket.send(json.dumps(initial))
        # Drain inbound frames; the loop ends when the peer closes.
        async for _ignored in websocket:
            continue
    except Exception:
        log.exception("WS handler error")
    finally:
        clients.discard(websocket)
        log.info("WS disconnected total=%d", len(clients))
async def post_slack(text: str):
    """Post ``text`` to Slack through the configured incoming webhook.

    Best effort. Returns immediately when ``settings.slack_webhook_url`` is
    unset. Network errors, timeouts and HTTP error responses (status >= 400)
    are logged, not raised, so an alert or summary failure does not break
    the caller.
    """
    if not settings.slack_webhook_url:
        log.debug("Slack webhook URL not configured; skipping")
        return
    # ClientTimeout(total=...) bounds the whole request: connect, send, read.
    timeout = aiohttp.ClientTimeout(total=10)
    try:
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(
                settings.slack_webhook_url, json={"text": text}
            ) as resp:
                status = resp.status
                body = await resp.text()
    except aiohttp.ClientError:
        log.exception("Slack network error")
        return
    except asyncio.TimeoutError:
        log.warning("Slack post timed out after %ss", timeout.total)
        return
    except Exception:
        log.exception("Slack unexpected error")
        return
    if status >= 400:
        # Surface rejected posts (bad/revoked webhook, malformed payload).
        log.warning("Slack post failed status=%s body=%s", status, body[:200])
    else:
        log.info("Slack post status=%s", status)
        log.debug("Slack response body=%s", body)
# Keep strong refs to background tasks. The event loop holds only weak
# references to tasks, so an otherwise-unreferenced task can be
# garbage-collected before it finishes.
_background_tasks: set[asyncio.Task] = set()


def _on_background_task_done(task: asyncio.Task) -> None:
    """Release ``task`` from ``_background_tasks`` and log its failure, if any."""
    _background_tasks.discard(task)
    if task.cancelled():
        return
    # Retrieving the exception here also stops asyncio from later reporting
    # "Task exception was never retrieved" for this task.
    exc = task.exception()
    if exc is not None:
        log.error("Background task %s failed", task.get_name(), exc_info=exc)


def _schedule(coro):
    """Fire-and-forget ``coro``: retain the task and log any exception it raises.

    Must be called with an event loop running (it uses
    ``asyncio.create_task``). Returns the created task; existing callers may
    ignore it.
    """
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_on_background_task_done)
    return task
async def _broadcast(msg: str):
    """Concurrently send ``msg`` to every connected WS client.

    Works on a copy of ``clients`` taken up front. Any client whose send
    raised is logged and removed from ``clients``.
    """
    targets = list(clients)
    outcomes = await asyncio.gather(
        *[peer.send(msg) for peer in targets],
        return_exceptions=True,
    )
    for peer, outcome in zip(targets, outcomes):
        if not isinstance(outcome, Exception):
            continue
        log.warning("WS dropping client: %s", outcome)
        clients.discard(peer)
def _evict_stale_devices(now: float) -> int:
    """Forget devices not seen within ``settings.prune_seconds``; return the count.

    Uses the same staleness rule as ``filter_active``. It is called only when
    the device cap is hit, so ids that stopped reporting do not lock new
    devices out for the lifetime of the process.
    """
    stale = [
        did
        for did in device_state
        if (now - last_seen.get(did, 0)) > settings.prune_seconds
    ]
    for did in stale:
        device_state.pop(did, None)
        last_seen.pop(did, None)
        # Keep an alert timestamp while its cooldown is still running, so a
        # device that reappears cannot bypass the per-device throttle.
        if (now - _last_alert.get(did, 0)) > settings.alert_cooldown:
            _last_alert.pop(did, None)
    return len(stale)


async def mqtt_loop():
    """Maintain an MQTT subscription, reconnecting with exponential backoff.

    Registers gmqtt callbacks. The message callback does four things:
      * validates each metrics frame;
      * updates ``device_state``/``last_seen``;
      * fans the update out to WebSocket clients;
      * raises throttled Slack CPU alerts.

    The outer loop reconnects after a disconnect or a connect error. Each
    retry doubles the delay (``next_backoff``) up to
    ``settings.mqtt_backoff_max``, and a successful connect resets it to
    ``settings.mqtt_backoff_min``. Runs until cancelled.
    """
    client = MQTTClient("backend-bridge")
    # Set by on_disconnect; the reconnect loop below waits on it.
    disconnected = asyncio.Event()

    # ---------------- Callbacks (must be sync defs) ----------------
    def on_connect(c, flags, rc, properties):
        log.info("MQTT connected rc=%s flags=%s props=%s", rc, flags, properties)
        try:
            c.subscribe(settings.topic, qos=1)
            log.info("MQTT subscribe requested topic=%s", settings.topic)
        except Exception:
            log.exception("MQTT subscribe failed")

    def on_disconnect(_c, packet, exc=None):
        log.warning("MQTT disconnected packet=%s exc=%s", packet, exc)
        disconnected.set()

    def on_subscribe(_c, mid, qos, properties):
        log.info("MQTT subscribed mid=%s qos=%s props=%s", mid, qos, properties)

    def on_message(_c, topic, payload, _qos, _properties):
        """Handle incoming MQTT messages from devices."""
        # Enforce the cap on the raw wire bytes *before* decoding, so an
        # oversized frame costs nothing and the limit really is in bytes.
        size = len(payload)
        if size > settings.max_payload_bytes:
            log.warning("MQTT payload too large topic=%s bytes=%d", topic, size)
            return
        if isinstance(payload, (bytes, bytearray)):
            try:
                raw = payload.decode("utf-8")
            except UnicodeDecodeError as err:
                log.warning("MQTT payload is not valid UTF-8 topic=%s: %s", topic, err)
                return
        else:
            raw = payload
        log.debug("MQTT raw topic=%s payload=%s", topic, raw[:100])

        metrics = parse_metrics(raw)
        if metrics is None:
            return

        device_id = metrics.device_id
        now = time.time()
        if device_id not in device_state and len(device_state) >= settings.max_devices:
            evicted = _evict_stale_devices(now)
            if evicted:
                log.info("MQTT evicted stale devices count=%d", evicted)
            if len(device_state) >= settings.max_devices:
                log.warning("MQTT device cap reached cap=%d dropping=%s",
                            settings.max_devices, device_id)
                return

        data = metrics.model_dump()
        device_state[device_id] = data
        last_seen[device_id] = now

        # broadcast to WebSocket clients
        if clients:
            msg = json.dumps({device_id: data})
            log.debug("WS broadcasting device=%s clients=%d", device_id, len(clients))
            _schedule(_broadcast(msg))

        # 🚨 Slack alert if CPU too high (throttled per device).
        if should_alert(metrics.cpu_percent, device_id, now, _last_alert,
                        settings.alert_cooldown, settings.cpu_alert_th):
            _last_alert[device_id] = now
            text = (
                f":rotating_light: CPU ALERT at {device_id}: "
                f"CPU {metrics.cpu_percent:.1f}%, RAM {metrics.mem_percent:.1f}%"
            )
            log.info("Alert triggered device=%s cpu=%.1f",
                     device_id, metrics.cpu_percent)
            _schedule(post_slack(text))

    # Attach callbacks
    client.on_connect = on_connect
    client.on_disconnect = on_disconnect
    client.on_subscribe = on_subscribe
    client.on_message = on_message

    # TLS context is reusable across reconnects; build it once.
    ssl_ctx = ssl.create_default_context()
    if settings.ca_cert:
        ssl_ctx.load_verify_locations(settings.ca_cert)
    if settings.client_cert and settings.client_key:
        ssl_ctx.load_cert_chain(
            certfile=settings.client_cert,
            keyfile=settings.client_key,
        )

    backoff = settings.mqtt_backoff_min
    while True:
        try:
            log.info("MQTT connecting host=%s port=%s",
                     settings.mqtt_broker, settings.mqtt_port)
            disconnected.clear()
            await client.connect(settings.mqtt_broker, port=settings.mqtt_port,
                                 ssl=ssl_ctx)
            log.info("MQTT connect() returned")
            backoff = settings.mqtt_backoff_min
            await disconnected.wait()
        except asyncio.CancelledError:
            # Explicit re-raise keeps cancellation working on Pythons where
            # CancelledError still subclasses Exception.
            raise
        except Exception:
            log.exception("MQTT connect/run failed")
        log.info("MQTT reconnecting in %ds", backoff)
        await asyncio.sleep(backoff)
        backoff = next_backoff(backoff, settings.mqtt_backoff_max)
def _fmt_pct(value) -> str:
    """Render a percentage for the Slack summary; non-numeric or ``None`` -> ``"?"``."""
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return "?"
    return f"{value:.1f}%"


async def slack_summary_loop():
    """Every ``settings.summary_interval`` seconds, post a table of active devices.

    Devices come from ``active_snapshot()``; nothing is posted when none are
    active. A missing optional metric is rendered as ``?``. ``model_dump``
    always includes ``gpu_percent`` (``None`` when absent), so a
    ``dict.get`` default alone would never apply.
    """
    while True:
        await asyncio.sleep(settings.summary_interval)
        snapshot = active_snapshot()
        if not snapshot:
            log.debug("Slack no active devices to summarize")
            continue
        lines = ["Metrics summary (last interval):", ""]
        for did, data in snapshot.items():
            lines.append(
                f"{did:15} | CPU {_fmt_pct(data.get('cpu_percent'))} "
                f"| RAM {_fmt_pct(data.get('mem_percent'))} "
                f"| Disk {_fmt_pct(data.get('disk_percent'))} "
                f"| GPU {_fmt_pct(data.get('gpu_percent'))}"
            )
        formatted = "```\n" + "\n".join(lines) + "\n```"
        log.info("Slack sending periodic summary devices=%d", len(snapshot))
        await post_slack(formatted)
async def main():
    """Run the bridge: the WebSocket server plus the MQTT and Slack-summary loops.

    The WebSocket server stays open for as long as the two loops run.
    """
    async with serve(ws_handler, settings.ws_host, settings.ws_port):
        log.info("WS server listening url=ws://%s:%s",
                 settings.ws_host, settings.ws_port)
        loops = (mqtt_loop(), slack_summary_loop())
        await asyncio.gather(*loops)
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the bridge: log it, no traceback.
        log.info("Backend stopped by user")