feat(security): validate full MQTT payload schema with Pydantic

Previously only device_id and cpu_percent had explicit checks. The
rest of the payload — mem_percent, disk_percent, gpu_percent,
timestamp, agent_cpu_percent, agent_mem_mb — was trusted as long
as json.loads accepted it, so a malicious or buggy publisher could
push:

  - mem_percent: "<script>alert(1)</script>" (rendered later in
    the WS dashboard / Slack summary as if numeric),
  - disk_percent: NaN (which compares False everywhere and breaks
    downstream chart aggregation),
  - extra keys ("evil": "<!channel>"), persisted in device_state
    and forwarded verbatim to WS clients.

Pydantic Metrics model now enforces the whole frame:
  - device_id pattern (same regex as validate_device_id),
  - percentages bounded to [0, 100],
  - explicit NaN rejection (a finite-value @field_validator on top
    of Field(ge/le); the bounds exclude -inf everywhere but +inf
    only where an upper bound is set, i.e. not for agent_mem_mb),
  - timestamp >= 0,
  - extra="forbid" so a frame carrying unknown keys is rejected
    at the door (the whole frame fails validation; keys are not
    silently stripped).

on_message now goes through parse_metrics() which logs a WARNING
with the structured pydantic error list on rejection.
This commit is contained in:
authentik Default Admin 2026-05-17 17:27:26 +00:00
parent adf6a7a1ce
commit 0e816fb966

View file

@ -18,7 +18,7 @@ from urllib.parse import parse_qs, urlsplit
import aiohttp
from gmqtt import Client as MQTTClient
from pydantic import Field
from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
from websockets import serve # websockets >= 10
@ -73,8 +73,60 @@ logging.basicConfig(
log = logging.getLogger("edgewatch.backend")
# Device-id grammar, written once: the compiled regex is built from the
# same pattern string that the Metrics schema passes to Field(pattern=...),
# so the two checks cannot drift apart.
_DEVICE_ID_PATTERN = r"^[A-Za-z0-9._-]{1,64}$"
_DEVICE_ID_RE = re.compile(_DEVICE_ID_PATTERN)
# Per-device time.time() value of the most recent alert; consulted by the
# alert throttle in on_message.
_last_alert: dict[str, float] = {}
class Metrics(BaseModel):
    """Schema for an incoming MQTT metrics frame.

    Validation fails for anything that does not match: device ids outside
    the ``_DEVICE_ID_PATTERN`` grammar, NaN or infinite numbers, values
    outside their declared bounds, missing required fields, and (via
    ``extra="forbid"``) any key that is not declared below.
    """

    model_config = ConfigDict(extra="forbid", str_strip_whitespace=False)

    device_id: str = Field(pattern=_DEVICE_ID_PATTERN)
    timestamp: int = Field(ge=0)
    cpu_percent: float = Field(ge=0, le=100)
    mem_percent: float = Field(ge=0, le=100)
    disk_percent: float = Field(ge=0, le=100)
    gpu_percent: Optional[float] = Field(default=None, ge=0, le=100)
    # NOTE(review): the 1000 ceiling presumably allows for per-core
    # accounting on multi-core hosts -- confirm with the agent.
    agent_cpu_percent: Optional[float] = Field(default=None, ge=0, le=1000)
    # No upper bound here, so the ge/le constraints alone would let +inf in;
    # the validator below is what keeps it out.
    agent_mem_mb: Optional[float] = Field(default=None, ge=0)

    @field_validator("cpu_percent", "mem_percent", "disk_percent",
                     "gpu_percent", "agent_cpu_percent", "agent_mem_mb")
    @classmethod
    def _reject_nonfinite(cls, value: Optional[float]) -> Optional[float]:
        """Reject NaN and +/-inf on every numeric field; pass ``None`` through.

        Raises:
            ValueError: if *value* is NaN or infinite (pydantic reports it
                as a validation error on the field).
        """
        # Function-scope import keeps this block self-contained.
        import math

        if value is None:
            return value
        # Bounds are not sufficient on their own: NaN compares False against
        # any bound, and a field with only a lower bound accepts +inf.
        # json.loads accepts the literals NaN / Infinity, so both can arrive
        # over the wire.
        if not math.isfinite(value):
            raise ValueError("must be a finite number")
        return value
def parse_metrics(raw: str) -> Optional[Metrics]:
    """Decode *raw* as JSON and validate it against :class:`Metrics`.

    Returns the validated ``Metrics`` instance, or ``None`` when the text
    is not valid JSON or the decoded value fails schema validation. Each
    failure is logged at WARNING. The log record carries no topic, so a
    caller that needs that context has to log it itself.
    """
    # Stage 1: text -> Python object.
    try:
        decoded = json.loads(raw)
    except (json.JSONDecodeError, UnicodeDecodeError, TypeError) as exc:
        log.warning("MQTT JSON decode failed: %s", exc)
        return None

    # Stage 2: Python object -> validated model.
    try:
        validated = Metrics.model_validate(decoded)
    except ValidationError as exc:
        log.warning("MQTT payload rejected by schema: %s", exc.errors())
        return None
    return validated
# Currently connected dashboard clients -- presumably WebSocket
# connections; confirm against the serve() handler.
clients: set = set()
# Most recent accepted metrics frame per device id (written by on_message).
device_state: dict[str, dict] = dict()
# time.time() of the most recent accepted frame per device id.
last_seen: dict[str, float] = dict()
@ -236,26 +288,22 @@ async def mqtt_loop():
def on_message(_c, topic, payload, _qos, _properties):
"""Handle incoming MQTT messages from devices."""
try:
raw = payload.decode() if isinstance(payload, (bytes, bytearray)) else payload
if len(raw) > settings.max_payload_bytes:
log.warning("MQTT payload too large topic=%s bytes=%d", topic, len(raw))
return
log.debug("MQTT raw topic=%s payload=%s", topic, raw[:100])
data = json.loads(raw)
except Exception:
log.exception("MQTT could not parse payload topic=%s", topic)
raw = payload.decode() if isinstance(payload, (bytes, bytearray)) else payload
if len(raw) > settings.max_payload_bytes:
log.warning("MQTT payload too large topic=%s bytes=%d", topic, len(raw))
return
log.debug("MQTT raw topic=%s payload=%s", topic, raw[:100])
metrics = parse_metrics(raw)
if metrics is None:
return
device_id = validate_device_id(data.get("device_id"))
if device_id is None:
log.warning("MQTT invalid device_id topic=%s raw=%r",
topic, data.get("device_id"))
return
device_id = metrics.device_id
if device_id not in device_state and len(device_state) >= settings.max_devices:
log.warning("MQTT device cap reached cap=%d dropping=%s",
settings.max_devices, device_id)
return
data = metrics.model_dump()
device_state[device_id] = data
last_seen[device_id] = time.time()
@ -266,16 +314,16 @@ async def mqtt_loop():
_schedule(_broadcast(msg))
# 🚨 Slack alert if CPU too high (throttled per device).
cpu = data.get("cpu_percent")
now = time.time()
if should_alert(cpu, device_id, now, _last_alert,
if should_alert(metrics.cpu_percent, device_id, now, _last_alert,
settings.alert_cooldown, settings.cpu_alert_th):
_last_alert[device_id] = now
text = (
f":rotating_light: CPU ALERT at {device_id}"
f"CPU {cpu:.1f}%, RAM {data.get('mem_percent')}%"
f"CPU {metrics.cpu_percent:.1f}%, RAM {metrics.mem_percent}%"
)
log.info("Alert triggered device=%s cpu=%.1f", device_id, cpu)
log.info("Alert triggered device=%s cpu=%.1f",
device_id, metrics.cpu_percent)
_schedule(post_slack(text))
# Attach callbacks