Previously only device_id and cpu_percent had explicit checks. The
rest of the payload — mem_percent, disk_percent, gpu_percent,
timestamp, agent_cpu_percent, agent_mem_mb — was trusted as long
as json.loads accepted it, so a malicious or buggy publisher could
push:
- mem_percent: "<script>alert(1)</script>" (rendered later in
the WS dashboard / Slack summary as if numeric),
- disk_percent: NaN (which compares False everywhere and breaks
downstream chart aggregation),
- extra keys ("evil": "<!channel>"), persisted in device_state
and forwarded verbatim to WS clients.
Pydantic Metrics model now enforces the whole frame:
- device_id pattern (same regex as validate_device_id),
- percentages bounded to [0, 100],
- explicit NaN rejection via a finite-value @field_validator, because
NaN slips past every Field(ge/le) comparison (the bounds reject +inf
only on fields that also set an upper bound),
- timestamp >= 0,
- extra="forbid" so a frame carrying any unknown key is rejected outright.
on_message now goes through parse_metrics() which logs a WARNING
with the structured pydantic error list on rejection.
407 lines
14 KiB
Python
407 lines
14 KiB
Python
"""
|
|
Backend bridge:
|
|
- Subscribes to MQTT broker (Mosquitto / AWS IoT Core).
|
|
- Broadcasts device metrics to WebSocket clients.
|
|
- Sends alerts and summaries to Slack via webhook.
|
|
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
import secrets
|
|
import ssl
|
|
import time
|
|
from typing import Optional
|
|
from urllib.parse import parse_qs, urlsplit
|
|
|
|
import aiohttp
|
|
from gmqtt import Client as MQTTClient
|
|
from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator
|
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
from websockets import serve # websockets >= 10
|
|
|
|
|
|
class BackendSettings(BaseSettings):
    """Typed configuration. Values are read from environment variables or
    a ``.env`` file at startup and validated by pydantic. A bad value
    (``WS_PORT=abc``, ``CPU_ALERT_TH=200``) fails fast with a readable
    message instead of crashing deep inside a callback.

    Values that a single ``Field`` bound cannot express are checked too:
    ``MQTT_BACKOFF_MAX`` may not be smaller than ``MQTT_BACKOFF_MIN``, and
    ``LOG_LEVEL`` must name a level the ``logging`` module knows (the value
    is normalized to upper case).
    """

    model_config = SettingsConfigDict(
        env_file=".env", env_file_encoding="utf-8", extra="ignore"
    )

    # MQTT
    mqtt_broker: str = ""
    mqtt_port: int = Field(default=8883, ge=1, le=65535)
    topic: str = "devices/+/metrics"
    ca_cert: Optional[str] = None
    client_cert: Optional[str] = None
    client_key: Optional[str] = None
    # Reconnect delay in seconds: starts at min, doubles, capped at max.
    mqtt_backoff_min: int = Field(default=1, ge=1)
    mqtt_backoff_max: int = Field(default=60, ge=1)

    # WebSocket
    ws_host: str = "0.0.0.0"
    ws_port: int = Field(default=6789, ge=1, le=65535)
    # Shared token required as ?token=... on the WS handshake. Empty = dev.
    ws_auth_token: str = ""

    # App
    # Devices silent for longer than this many seconds drop out of snapshots.
    prune_seconds: int = Field(default=30, ge=0)
    # CPU percentage at or above which a Slack alert is considered.
    cpu_alert_th: float = Field(default=90.0, ge=0, le=100)
    slack_webhook_url: Optional[str] = None
    # Seconds between periodic Slack summaries.
    summary_interval: int = Field(default=60, ge=1)

    # Hardening against untrusted MQTT input.
    max_payload_bytes: int = Field(default=16 * 1024, ge=64)
    max_devices: int = Field(default=1000, ge=1)
    # Minimum seconds between two CPU alerts for the same device.
    alert_cooldown: int = Field(default=60, ge=0)

    log_level: str = "INFO"

    @field_validator("mqtt_backoff_max")
    @classmethod
    def _backoff_max_not_below_min(cls, value, info):
        # mqtt_backoff_min is declared earlier, so it is already validated and
        # present in info.data (absent only if it failed its own validation).
        minimum = info.data.get("mqtt_backoff_min")
        if minimum is not None and value < minimum:
            raise ValueError(f"must be >= mqtt_backoff_min ({minimum})")
        return value

    @field_validator("log_level")
    @classmethod
    def _known_log_level(cls, value):
        # getLevelName() maps a registered level name to its int and returns
        # a "Level X" string otherwise, so an int result means basicConfig
        # will accept it.
        level = value.upper()
        if not isinstance(logging.getLevelName(level), int):
            raise ValueError(f"unknown log level {value!r}")
        return level
|
|
|
|
|
|
# Load and validate the configuration once, when the module is imported.
settings = BackendSettings()

# Configure the root logger from the validated settings.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    level=settings.log_level.upper(),
)
log = logging.getLogger(name="edgewatch.backend")
|
|
|
|
# One literal for the device-id grammar. The raw pattern is handed to the
# pydantic Metrics.device_id constraint and the compiled form is used by
# validate_device_id(); deriving one from the other keeps them from drifting.
_DEVICE_ID_PATTERN = r"^[A-Za-z0-9._-]{1,64}$"
_DEVICE_ID_RE = re.compile(_DEVICE_ID_PATTERN)

# device_id -> time.time() of that device's most recent Slack CPU alert.
_last_alert: dict[str, float] = {}
|
|
|
|
|
|
class Metrics(BaseModel):
    """Schema for an incoming MQTT metrics frame.

    Pydantic rejects everything that does not match: oversize/odd device
    ids, NaN/inf, percentages outside [0, 100], missing required fields,
    and (via ``extra="forbid"``) any unexpected key a malicious publisher
    might try to smuggle in.
    """

    model_config = ConfigDict(extra="forbid", str_strip_whitespace=False)

    device_id: str = Field(pattern=_DEVICE_ID_PATTERN)
    timestamp: int = Field(ge=0)
    cpu_percent: float = Field(ge=0, le=100)
    mem_percent: float = Field(ge=0, le=100)
    disk_percent: float = Field(ge=0, le=100)
    gpu_percent: Optional[float] = Field(default=None, ge=0, le=100)
    agent_cpu_percent: Optional[float] = Field(default=None, ge=0, le=1000)
    # Only a lower bound here, so +inf is NOT excluded by the Field itself.
    agent_mem_mb: Optional[float] = Field(default=None, ge=0)

    @field_validator("cpu_percent", "mem_percent", "disk_percent",
                     "gpu_percent", "agent_cpu_percent", "agent_mem_mb")
    @classmethod
    def _reject_nonfinite(cls, value):
        """Reject NaN and +/-inf on every float field.

        json.loads accepts the non-standard tokens NaN/Infinity. NaN compares
        False against every bound, so ge/le let it through; and a field with
        only ``ge`` (agent_mem_mb) would accept +inf, which json.dumps later
        re-emits as the invalid JSON token ``Infinity`` for WS clients.
        """
        import math

        if value is not None and not math.isfinite(value):
            raise ValueError("must be a finite number")
        return value
|
|
|
|
|
|
def parse_metrics(raw: str) -> Optional[Metrics]:
    """Validate one raw JSON frame; return a Metrics instance or None.

    Every rejection (undecodable JSON, schema mismatch including NaN/inf and
    unknown keys) is logged at WARNING and turned into ``None`` so that
    nothing raises out of the MQTT callback. The payload size cap is
    enforced by the caller before this is called.
    """
    try:
        data = json.loads(raw)
    except (ValueError, TypeError, RecursionError) as err:
        # ValueError covers json.JSONDecodeError, UnicodeDecodeError (bytes
        # input) and the int-digit limit on huge integer literals.
        # RecursionError: the decoder recurses once per nesting level, so a
        # small frame such as "[" * 5000 exceeds the recursion limit.
        log.warning("MQTT JSON decode failed: %s", err)
        return None
    try:
        return Metrics.model_validate(data)
    except ValidationError as err:
        log.warning("MQTT payload rejected by schema: %s", err.errors())
        return None
|
|
|
|
# Currently connected WebSocket clients (broadcast targets).
clients: set = set()
# device_id -> most recent validated metrics dict for that device.
device_state: dict = {}
# device_id -> time.time() at which that device last sent a valid frame.
last_seen: dict = {}
|
|
|
|
|
|
def filter_active(state, seen, now, prune_seconds):
    """Return the subset of ``state`` whose devices were seen recently.

    A device counts as active when ``now - seen[device]`` is at most
    ``prune_seconds``; a device missing from ``seen`` is treated as last seen
    at time 0. No globals are touched, which keeps the rule easy to test.
    """
    active = {}
    for device, payload in state.items():
        age = now - seen.get(device, 0)
        if age <= prune_seconds:
            active[device] = payload
    return active
|
|
|
|
|
|
def active_snapshot():
    """Devices that reported within the last ``settings.prune_seconds``.

    Thin wrapper that feeds the module-level state and the current wall-clock
    time into ``filter_active``.
    """
    current_time = time.time()
    return filter_active(device_state, last_seen, current_time, settings.prune_seconds)
|
|
|
|
|
|
def validate_device_id(raw):
    """Return ``raw`` if it is a safe device id, else ``None``.

    Restricted to ``[A-Za-z0-9._-]{1,64}`` so it cannot smuggle Slack mrkdwn
    metacharacters, newlines or unbounded lengths through the alert path.
    Non-string input is rejected.
    """
    # fullmatch, not match: with re.match a ``$`` anchor also matches just
    # before a trailing "\n", so "dev1\n" would have been accepted.
    if isinstance(raw, str) and _DEVICE_ID_RE.fullmatch(raw):
        return raw
    return None
|
|
|
|
|
|
def should_alert(cpu, device_id, now, last_alert_map, cooldown, threshold):
    """Decide whether a Slack alert is warranted for this sample.

    Pure function: same inputs -> same answer. Returns True only when
    ``cpu`` is a real number (bools rejected even though they are ints)
    within [0, 100] (so NaN/inf never alert), ``cpu >= threshold``, and
    the device either has never alerted or last alerted more than
    ``cooldown`` seconds before ``now``.
    """
    if isinstance(cpu, bool) or not isinstance(cpu, (int, float)):
        return False
    # Chained comparisons are False for NaN, so this also filters NaN.
    if not (0 <= cpu <= 100 and cpu >= threshold):
        return False
    last = last_alert_map.get(device_id)
    # A device that never alerted is always eligible. Defaulting to 0 would
    # wrongly suppress the first alert whenever now <= cooldown (e.g. a
    # monotonic clock or small test timestamps).
    return last is None or (now - last) > cooldown
|
|
|
|
|
|
def next_backoff(current, ceiling):
    """Return the next retry delay: ``current`` doubled, capped at ``ceiling``.

    Side-effect free so the reconnect schedule can be unit-tested directly.
    """
    doubled = current * 2
    return ceiling if doubled > ceiling else doubled
|
|
|
|
|
|
def _ws_token_ok(websocket) -> bool:
    """Return True if the handshake carries a valid shared token.

    When ``settings.ws_auth_token`` is empty the server is in open mode and
    accepts any client (useful for local dev). When set, the client must
    connect to ``ws://host:port/?token=<value>`` and the comparison is
    constant-time. Malformed or non-ASCII tokens are rejected, never raised.
    """
    expected = settings.ws_auth_token
    if not expected:
        return True
    # websockets >= 11 exposes the full request-target (path + query) via
    # connection.request.path; older versions kept it on connection.path.
    request = getattr(websocket, "request", None)
    target = getattr(request, "path", None) or getattr(websocket, "path", "") or ""
    token = (parse_qs(urlsplit(target).query).get("token") or [""])[0]
    if not token:
        return False
    # compare_digest() raises TypeError for str arguments containing
    # non-ASCII characters (e.g. ?token=%C3%A9 after percent-decoding), so
    # compare the UTF-8 bytes instead of the str values.
    return secrets.compare_digest(token.encode("utf-8"), expected.encode("utf-8"))
|
|
|
|
|
|
async def ws_handler(websocket):
    """Serve one WebSocket client for the lifetime of its connection.

    Closes the connection with code 1008 when the shared-token check fails.
    Otherwise registers the client as a broadcast target, sends the current
    active-device snapshot (if any), then ignores inbound frames until the
    peer disconnects. The client is always unregistered on exit.
    """
    # Local import keeps the module import block untouched; ConnectionClosed
    # is the common base of the clean and abnormal close exceptions.
    from websockets.exceptions import ConnectionClosed

    if not _ws_token_ok(websocket):
        log.warning("WS rejecting unauthorized connection")
        await websocket.close(code=1008, reason="unauthorized")
        return

    clients.add(websocket)
    log.info("WS connected total=%d", len(clients))
    try:
        snap = active_snapshot()
        if snap:
            log.info("WS sending initial snapshot devices=%d", len(snap))
            await websocket.send(json.dumps(snap))

        # Inbound messages are not used; iterating only keeps the handler
        # alive until the connection ends.
        async for _ in websocket:
            pass
    except ConnectionClosed as err:
        # Peers that vanish without a proper close handshake (killed tab,
        # dropped network) are routine; no traceback needed.
        log.info("WS connection closed: %s", err)
    except Exception:
        log.exception("WS handler error")
    finally:
        clients.discard(websocket)
        log.info("WS disconnected total=%d", len(clients))
|
|
|
|
|
|
async def post_slack(text: str):
    """Send ``text`` to Slack through the configured incoming webhook.

    Best effort: a missing webhook URL is a silent no-op, and network
    errors, timeouts and non-2xx replies are logged rather than raised, so
    callers never need to guard this call.
    """
    url = settings.slack_webhook_url
    if not url:
        log.debug("Slack webhook URL not configured; skipping")
        return
    timeout = aiohttp.ClientTimeout(total=10)
    try:
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(url, json={"text": text}) as resp:
                body = await resp.text()
        if resp.status >= 400:
            # Slack reports problems (bad payload, revoked webhook) via
            # status + a short plain-text body; surface both.
            log.warning("Slack post failed status=%s body=%s", resp.status, body[:200])
        else:
            log.info("Slack post status=%s", resp.status)
            log.debug("Slack response body=%s", body)
    except asyncio.TimeoutError:
        log.warning("Slack post timed out")
    except aiohttp.ClientError:
        log.exception("Slack network error")
    except Exception:
        log.exception("Slack unexpected error")
|
|
|
|
|
|
# The event loop keeps only weak references to tasks, so a fire-and-forget
# task with no other reference can be garbage-collected mid-execution.
# Holding strong refs here keeps scheduled work alive until it finishes.
_background_tasks: set[asyncio.Task] = set()


def _schedule(coro):
    """Run ``coro`` as a background task; keep it alive and log failures.

    Returns the created task; existing callers may ignore it.
    """
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_on_background_task_done)
    return task


def _on_background_task_done(task):
    """Done-callback: drop the strong ref and log any exception the task raised."""
    _background_tasks.discard(task)
    if task.cancelled():
        return
    # Retrieving the exception also prevents asyncio's later
    # "Task exception was never retrieved" message.
    exc = task.exception()
    if exc is not None:
        log.error("Background task %s failed", task.get_name(), exc_info=exc)
|
|
|
|
|
|
async def _broadcast(msg: str):
    """Push ``msg`` to all current WS clients concurrently.

    Any client whose send raises is logged and removed from ``clients``.
    """
    # NOTE(review): there is no per-send timeout, so a client that stops
    # reading may stall this gather; consider bounding it.
    targets = tuple(clients)
    outcomes = await asyncio.gather(
        *[peer.send(msg) for peer in targets],
        return_exceptions=True,
    )
    failed = [
        (peer, outcome)
        for peer, outcome in zip(targets, outcomes)
        if isinstance(outcome, Exception)
    ]
    for peer, error in failed:
        log.warning("WS dropping client: %s", error)
        clients.discard(peer)
|
|
|
|
|
|
def _decode_payload(topic, payload):
    """Return an MQTT payload as text, or None (after a WARNING) if rejected.

    The ``max_payload_bytes`` cap is checked against the raw byte length
    before decoding, so oversize frames are never decoded and multi-byte
    UTF-8 cannot stretch the limit. Undecodable bytes are rejected instead
    of raising out of the MQTT callback.
    """
    if isinstance(payload, str):
        size = len(payload.encode("utf-8", "surrogatepass"))
    else:
        size = len(payload)
    if size > settings.max_payload_bytes:
        log.warning("MQTT payload too large topic=%s bytes=%d", topic, size)
        return None
    if isinstance(payload, str):
        return payload
    try:
        return bytes(payload).decode("utf-8")
    except (UnicodeDecodeError, TypeError) as err:
        log.warning("MQTT payload could not be decoded as UTF-8 topic=%s: %s", topic, err)
        return None


def _evict_stale_devices(now):
    """Forget devices idle for more than ``settings.prune_seconds``.

    Stale devices are the complement of filter_active()'s "active" rule.
    Returns how many devices were removed, so the device cap is not held
    forever by ids that stopped reporting long ago.
    """
    stale = [
        did for did in device_state
        if now - last_seen.get(did, 0) > settings.prune_seconds
    ]
    for did in stale:
        device_state.pop(did, None)
        last_seen.pop(did, None)
        # Drop the alert timestamp only once its cooldown has expired, so an
        # evicted-then-returning device cannot bypass alert throttling.
        if now - _last_alert.get(did, now) > settings.alert_cooldown:
            _last_alert.pop(did, None)
    return len(stale)


async def mqtt_loop():
    """Maintain an MQTT subscription, reconnecting with exponential backoff.

    Connects over TLS (optionally with a custom CA and client certificate)
    and subscribes to ``settings.topic`` from ``on_connect``. Each message is
    size/encoding checked, schema-validated, stored, fanned out to WS
    clients and checked for the throttled Slack CPU alert. After a
    disconnect or failed connect it sleeps ``backoff`` seconds, doubling up
    to ``settings.mqtt_backoff_max``; a successful connect resets it.
    """
    client = MQTTClient("backend-bridge")
    disconnected = asyncio.Event()

    # ---------------- Callbacks (must be sync defs) ----------------
    def on_connect(c, flags, rc, properties):
        log.info("MQTT connected rc=%s flags=%s props=%s", rc, flags, properties)
        try:
            c.subscribe(settings.topic, qos=1)
            log.info("MQTT subscribe requested topic=%s", settings.topic)
        except Exception:
            log.exception("MQTT subscribe failed")

    def on_disconnect(_c, packet, exc=None):
        log.warning("MQTT disconnected packet=%s exc=%s", packet, exc)
        # Wakes the reconnect loop below.
        disconnected.set()

    def on_subscribe(_c, mid, qos, properties):
        log.info("MQTT subscribed mid=%s qos=%s props=%s", mid, qos, properties)

    def on_message(_c, topic, payload, _qos, _properties):
        """Handle incoming MQTT messages from devices."""
        raw = _decode_payload(topic, payload)
        if raw is None:
            return
        log.debug("MQTT raw topic=%s payload=%s", topic, raw[:100])

        metrics = parse_metrics(raw)
        if metrics is None:
            return

        device_id = metrics.device_id
        now = time.time()
        # At the cap, try reclaiming slots from silent devices before
        # refusing a new one; only existing devices bypass the cap.
        if (device_id not in device_state
                and len(device_state) >= settings.max_devices
                and _evict_stale_devices(now) == 0):
            log.warning("MQTT device cap reached cap=%d dropping=%s",
                        settings.max_devices, device_id)
            return
        data = metrics.model_dump()
        device_state[device_id] = data
        last_seen[device_id] = now

        # broadcast to WebSocket clients
        if clients:
            msg = json.dumps({device_id: data})
            log.debug("WS broadcasting device=%s clients=%d", device_id, len(clients))
            _schedule(_broadcast(msg))

        # 🚨 Slack alert if CPU too high (throttled per device).
        if should_alert(metrics.cpu_percent, device_id, now, _last_alert,
                        settings.alert_cooldown, settings.cpu_alert_th):
            _last_alert[device_id] = now
            text = (
                f":rotating_light: CPU ALERT at {device_id} — "
                f"CPU {metrics.cpu_percent:.1f}%, RAM {metrics.mem_percent}%"
            )
            log.info("Alert triggered device=%s cpu=%.1f",
                     device_id, metrics.cpu_percent)
            _schedule(post_slack(text))

    # Attach callbacks
    client.on_connect = on_connect
    client.on_disconnect = on_disconnect
    client.on_subscribe = on_subscribe
    client.on_message = on_message

    # TLS context is reusable across reconnects; build it once.
    ssl_ctx = ssl.create_default_context()
    if settings.ca_cert:
        ssl_ctx.load_verify_locations(settings.ca_cert)
    if settings.client_cert and settings.client_key:
        ssl_ctx.load_cert_chain(
            certfile=settings.client_cert,
            keyfile=settings.client_key,
        )

    backoff = settings.mqtt_backoff_min
    while True:
        try:
            log.info("MQTT connecting host=%s port=%s",
                     settings.mqtt_broker, settings.mqtt_port)
            disconnected.clear()
            await client.connect(settings.mqtt_broker, port=settings.mqtt_port,
                                 ssl=ssl_ctx)
            log.info("MQTT connect() returned")
            backoff = settings.mqtt_backoff_min
            await disconnected.wait()
        except asyncio.CancelledError:
            raise
        except Exception:
            log.exception("MQTT connect/run failed")

        log.info("MQTT reconnecting in %ds", backoff)
        await asyncio.sleep(backoff)
        backoff = next_backoff(backoff, settings.mqtt_backoff_max)
|
|
|
|
|
|
def _format_pct(value):
    """Render a percentage with one decimal, or ``n/a`` when absent/non-numeric."""
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return "n/a"
    return f"{value:.1f}%"


def _format_summary(snapshot):
    """Build the Slack summary text (a fenced code block) for ``snapshot``.

    ``snapshot`` maps device_id -> metrics dict. Missing or ``None`` values
    (e.g. no GPU) are shown as ``n/a`` instead of a bare ``None``.
    """
    lines = ["Metrics summary (last interval):", ""]
    for did, data in snapshot.items():
        lines.append(
            f"{did:15} | CPU {_format_pct(data.get('cpu_percent'))} "
            f"| RAM {_format_pct(data.get('mem_percent'))} "
            f"| Disk {_format_pct(data.get('disk_percent'))} "
            f"| GPU {_format_pct(data.get('gpu_percent'))}"
        )
    return "```\n" + "\n".join(lines) + "\n```"


async def slack_summary_loop():
    """Every ``settings.summary_interval`` seconds, post active devices to Slack.

    Intervals with no active devices are skipped silently (debug log only).
    """
    while True:
        await asyncio.sleep(settings.summary_interval)
        snapshot = active_snapshot()
        if not snapshot:
            log.debug("Slack no active devices to summarize")
            continue

        log.info("Slack sending periodic summary devices=%d", len(snapshot))
        await post_slack(_format_summary(snapshot))
|
|
|
|
|
|
async def main():
    """Run the WS server and, inside it, the MQTT and Slack-summary loops.

    Returns only when one of the loops ends or raises; the WS server is shut
    down when the ``async with`` block exits.
    """
    server = serve(ws_handler, settings.ws_host, settings.ws_port)
    async with server:
        log.info("WS server listening url=ws://%s:%s",
                 settings.ws_host, settings.ws_port)
        workers = (mqtt_loop(), slack_summary_loop())
        await asyncio.gather(*workers)
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run until interrupted; Ctrl-C is logged as a
    # normal shutdown rather than surfacing a KeyboardInterrupt traceback.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        log.info("Backend stopped by user")
|