feat(security): validate full MQTT payload schema with Pydantic
Previously only device_id and cpu_percent had explicit checks. The
rest of the payload — mem_percent, disk_percent, gpu_percent,
timestamp, agent_cpu_percent, agent_mem_mb — was trusted as long
as json.loads accepted it, so a malicious or buggy publisher could
push:
- mem_percent: "<script>alert(1)</script>" (rendered later in
the WS dashboard / Slack summary as if numeric),
- disk_percent: NaN (which compares False everywhere and breaks
downstream chart aggregation),
- extra keys ("evil": "<!channel>"), persisted in device_state
and forwarded verbatim to WS clients.
Pydantic Metrics model now enforces the whole frame:
- device_id pattern (same regex as validate_device_id),
- percentages bounded to [0, 100],
- explicit NaN rejection (a @field_validator on top of Field(ge/le),
whose bounds already exclude inf wherever both ge and le are set),
- timestamp >= 0,
- extra="forbid" so frames carrying unknown keys are rejected at the door.
on_message now goes through parse_metrics() which logs a WARNING
with the structured pydantic error list on rejection.
This commit is contained in:
parent
adf6a7a1ce
commit
0e816fb966
1 changed files with 67 additions and 19 deletions
86
backend.py
86
backend.py
|
|
@ -18,7 +18,7 @@ from urllib.parse import parse_qs, urlsplit
|
|||
|
||||
import aiohttp
|
||||
from gmqtt import Client as MQTTClient
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from websockets import serve # websockets >= 10
|
||||
|
||||
|
|
@ -73,8 +73,60 @@ logging.basicConfig(
|
|||
log = logging.getLogger("edgewatch.backend")
|
||||
|
||||
_DEVICE_ID_RE = re.compile(r"^[A-Za-z0-9._-]{1,64}$")
|
||||
_DEVICE_ID_PATTERN = r"^[A-Za-z0-9._-]{1,64}$"
|
||||
_last_alert: dict[str, float] = {}
|
||||
|
||||
|
||||
class Metrics(BaseModel):
    """Schema for an incoming MQTT metrics frame.

    Pydantic rejects everything that does not match: device ids that fail
    the id pattern, NaN/inf in any numeric metric, values outside the
    declared bounds, missing required fields, and (via ``extra="forbid"``)
    any unexpected key a malicious publisher might try to smuggle in.

    ``device_id``, ``timestamp``, ``cpu_percent``, ``mem_percent`` and
    ``disk_percent`` are required; the remaining fields default to ``None``.
    """

    # extra="forbid": a frame carrying any undeclared key fails validation
    # as a whole (the key is not silently stripped).
    model_config = ConfigDict(extra="forbid", str_strip_whitespace=False)

    device_id: str = Field(pattern=_DEVICE_ID_PATTERN)
    timestamp: int = Field(ge=0)
    cpu_percent: float = Field(ge=0, le=100)
    mem_percent: float = Field(ge=0, le=100)
    disk_percent: float = Field(ge=0, le=100)
    gpu_percent: Optional[float] = Field(default=None, ge=0, le=100)
    # NOTE(review): upper bound is 1000, not 100 — presumably a per-core sum;
    # confirm against the agent that publishes this value.
    agent_cpu_percent: Optional[float] = Field(default=None, ge=0, le=1000)
    # No upper bound here, so Field() alone would let +inf through; the
    # validator below is what rejects it.
    agent_mem_mb: Optional[float] = Field(default=None, ge=0)

    @field_validator("cpu_percent", "mem_percent", "disk_percent",
                     "gpu_percent", "agent_cpu_percent", "agent_mem_mb")
    @classmethod
    def _reject_nonfinite(cls, value):
        """Return *value* unchanged, raising ValueError for NaN or +/-inf.

        ``None`` (an omitted optional metric) is passed through untouched.
        """
        # Function-scope import keeps this block self-contained.
        import math

        if value is None:
            return value
        # Bounds are not sufficient on their own: a field with only ``ge``
        # (agent_mem_mb) accepts +inf, so NaN and both infinities are
        # rejected explicitly here.
        if not math.isfinite(value):
            raise ValueError("must be a finite number")
        return value
|
||||
|
||||
|
||||
def parse_metrics(raw: str) -> Optional[Metrics]:
    """Validate a raw JSON frame; return a Metrics instance or None.

    Args:
        raw: The JSON text of one metrics frame.

    Returns:
        The validated ``Metrics`` model, or ``None`` when *raw* is not
        decodable JSON or does not satisfy the ``Metrics`` schema.

    Every failure is logged at WARNING and never raised. The log lines
    written here do not include the MQTT topic; a caller that needs that
    context has to log it itself.
    """
    try:
        data = json.loads(raw)
    except (
        json.JSONDecodeError,
        UnicodeDecodeError,
        TypeError,
        # Deeply nested input (e.g. thousands of "[") exhausts the decoder's
        # recursion budget; treat it as just another malformed frame instead
        # of letting it escape into the caller.
        RecursionError,
    ) as err:
        log.warning("MQTT JSON decode failed: %s", err)
        return None

    try:
        return Metrics.model_validate(data)
    except ValidationError as err:
        # err.errors() is the structured per-field error list from pydantic.
        log.warning("MQTT payload rejected by schema: %s", err.errors())
        return None
|
||||
|
||||
clients = set()
|
||||
device_state = {}
|
||||
last_seen = {}
|
||||
|
|
@ -236,26 +288,22 @@ async def mqtt_loop():
|
|||
|
||||
def on_message(_c, topic, payload, _qos, _properties):
|
||||
"""Handle incoming MQTT messages from devices."""
|
||||
try:
|
||||
raw = payload.decode() if isinstance(payload, (bytes, bytearray)) else payload
|
||||
if len(raw) > settings.max_payload_bytes:
|
||||
log.warning("MQTT payload too large topic=%s bytes=%d", topic, len(raw))
|
||||
return
|
||||
log.debug("MQTT raw topic=%s payload=%s", topic, raw[:100])
|
||||
data = json.loads(raw)
|
||||
except Exception:
|
||||
log.exception("MQTT could not parse payload topic=%s", topic)
|
||||
raw = payload.decode() if isinstance(payload, (bytes, bytearray)) else payload
|
||||
if len(raw) > settings.max_payload_bytes:
|
||||
log.warning("MQTT payload too large topic=%s bytes=%d", topic, len(raw))
|
||||
return
|
||||
log.debug("MQTT raw topic=%s payload=%s", topic, raw[:100])
|
||||
|
||||
metrics = parse_metrics(raw)
|
||||
if metrics is None:
|
||||
return
|
||||
|
||||
device_id = validate_device_id(data.get("device_id"))
|
||||
if device_id is None:
|
||||
log.warning("MQTT invalid device_id topic=%s raw=%r",
|
||||
topic, data.get("device_id"))
|
||||
return
|
||||
device_id = metrics.device_id
|
||||
if device_id not in device_state and len(device_state) >= settings.max_devices:
|
||||
log.warning("MQTT device cap reached cap=%d dropping=%s",
|
||||
settings.max_devices, device_id)
|
||||
return
|
||||
data = metrics.model_dump()
|
||||
device_state[device_id] = data
|
||||
last_seen[device_id] = time.time()
|
||||
|
||||
|
|
@ -266,16 +314,16 @@ async def mqtt_loop():
|
|||
_schedule(_broadcast(msg))
|
||||
|
||||
# 🚨 Slack alert if CPU too high (throttled per device).
|
||||
cpu = data.get("cpu_percent")
|
||||
now = time.time()
|
||||
if should_alert(cpu, device_id, now, _last_alert,
|
||||
if should_alert(metrics.cpu_percent, device_id, now, _last_alert,
|
||||
settings.alert_cooldown, settings.cpu_alert_th):
|
||||
_last_alert[device_id] = now
|
||||
text = (
|
||||
f":rotating_light: CPU ALERT at {device_id} — "
|
||||
f"CPU {cpu:.1f}%, RAM {data.get('mem_percent')}%"
|
||||
f"CPU {metrics.cpu_percent:.1f}%, RAM {metrics.mem_percent}%"
|
||||
)
|
||||
log.info("Alert triggered device=%s cpu=%.1f", device_id, cpu)
|
||||
log.info("Alert triggered device=%s cpu=%.1f",
|
||||
device_id, metrics.cpu_percent)
|
||||
_schedule(post_slack(text))
|
||||
|
||||
# Attach callbacks
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue