feat(config): typed settings via pydantic-settings

Replace the scattered os.getenv() + int()/float() pattern with a
BaseSettings class on both modules. Wins:

  - bad config now fails at import with a readable pydantic error
    (WS_PORT=abc no longer produces a ValueError stack from inside
    main()); ports are bounded to [1, 65535], cpu_alert_th to [0,100],
    backoff_min/interval to >= 1,
  - .env loading moves into pydantic-settings (env_file in
    SettingsConfigDict), so the manual load_dotenv() call is gone,
  - every callback now reads from a single ``settings`` instance, so
    runtime overrides are possible (tests use monkeypatch on
    backend.settings instead of patching module-level constants).

Test for ws_token is updated to patch backend.settings.ws_auth_token
rather than the old WS_AUTH_TOKEN module constant; the contract is
unchanged so all 55 tests still pass.

Pydantic stack pinned: pydantic==2.13.4, pydantic-core==2.46.4,
pydantic-settings==2.14.1 (plus annotated-types and typing-inspection
as transitives). pip-audit remains clean.
This commit is contained in:
authentik Default Admin 2026-05-17 17:09:50 +00:00
parent 06c665843d
commit adf6a7a1ce
4 changed files with 139 additions and 110 deletions

View file

@ -13,59 +13,68 @@ import re
import secrets
import ssl
import time
from typing import Optional
from urllib.parse import parse_qs, urlsplit
import aiohttp
from gmqtt import Client as MQTTClient
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
from websockets import serve # websockets >= 10
from dotenv import load_dotenv
# Load environment variables from .env
load_dotenv()
# Structured logging. LOG_LEVEL accepts DEBUG / INFO / WARNING / ERROR.
class BackendSettings(BaseSettings):
"""Typed configuration. Values are read from environment variables or
a ``.env`` file at startup and validated by pydantic. A bad value
(``WS_PORT=abc``, ``CPU_ALERT_TH=200``) fails fast with a readable
message instead of crashing deep inside a callback.
"""
model_config = SettingsConfigDict(
env_file=".env", env_file_encoding="utf-8", extra="ignore"
)
# MQTT
mqtt_broker: str = ""
mqtt_port: int = Field(default=8883, ge=1, le=65535)
topic: str = "devices/+/metrics"
ca_cert: Optional[str] = None
client_cert: Optional[str] = None
client_key: Optional[str] = None
mqtt_backoff_min: int = Field(default=1, ge=1)
mqtt_backoff_max: int = Field(default=60, ge=1)
# WebSocket
ws_host: str = "0.0.0.0"
ws_port: int = Field(default=6789, ge=1, le=65535)
# Shared token required as ?token=... on the WS handshake. Empty = dev.
ws_auth_token: str = ""
# App
prune_seconds: int = Field(default=30, ge=0)
cpu_alert_th: float = Field(default=90.0, ge=0, le=100)
slack_webhook_url: Optional[str] = None
summary_interval: int = Field(default=60, ge=1)
# Hardening against untrusted MQTT input.
max_payload_bytes: int = Field(default=16 * 1024, ge=64)
max_devices: int = Field(default=1000, ge=1)
alert_cooldown: int = Field(default=60, ge=0)
log_level: str = "INFO"
settings = BackendSettings()
logging.basicConfig(
level=os.getenv("LOG_LEVEL", "INFO").upper(),
level=settings.log_level.upper(),
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
log = logging.getLogger("edgewatch.backend")
# MQTT configs
MQTT_BROKER = os.getenv("MQTT_BROKER")
MQTT_PORT = int(os.getenv("MQTT_PORT", "8883"))
TOPIC = os.getenv("TOPIC", "devices/+/metrics")
# WebSocket configs
WS_HOST = os.getenv("WS_HOST", "0.0.0.0")
WS_PORT = int(os.getenv("WS_PORT", "6789"))
# Shared token required as ?token=... on the WS handshake. When empty the
# server accepts any client (dev mode); set this in production.
WS_AUTH_TOKEN = os.getenv("WS_AUTH_TOKEN", "")
# App configs
PRUNE_SECONDS = int(os.getenv("PRUNE_SECONDS", "30"))
CPU_ALERT_TH = float(os.getenv("CPU_ALERT_TH", "90"))
SLACK_WEBHOOK_URL = os.getenv("SLACK_WEBHOOK_URL")
SUMMARY_INTERVAL = int(os.getenv("SUMMARY_INTERVAL", "60"))
# Hard limits on untrusted MQTT input.
MAX_PAYLOAD_BYTES = int(os.getenv("MAX_PAYLOAD_BYTES", str(16 * 1024)))
MAX_DEVICES = int(os.getenv("MAX_DEVICES", "1000"))
_DEVICE_ID_RE = re.compile(r"^[A-Za-z0-9._-]{1,64}$")
# Slack alert throttling: re-alert per device at most once per cooldown.
ALERT_COOLDOWN = int(os.getenv("ALERT_COOLDOWN", "60"))
_last_alert: dict[str, float] = {}
# MQTT reconnect backoff bounds (seconds).
MQTT_BACKOFF_MIN = int(os.getenv("MQTT_BACKOFF_MIN", "1"))
MQTT_BACKOFF_MAX = int(os.getenv("MQTT_BACKOFF_MAX", "60"))
def next_backoff(current, ceiling):
"""Double the backoff but never exceed ``ceiling``. Pure for testability."""
return min(current * 2, ceiling)
clients = set()
device_state = {}
last_seen = {}
@ -84,8 +93,8 @@ def filter_active(state, seen, now, prune_seconds):
def active_snapshot():
"""Return only devices that have updated within PRUNE_SECONDS."""
return filter_active(device_state, last_seen, time.time(), PRUNE_SECONDS)
"""Return only devices that have updated within ``settings.prune_seconds``."""
return filter_active(device_state, last_seen, time.time(), settings.prune_seconds)
def validate_device_id(raw):
@ -114,21 +123,27 @@ def should_alert(cpu, device_id, now, last_alert_map, cooldown, threshold):
)
def next_backoff(current, ceiling):
"""Double the backoff but never exceed ``ceiling``. Pure for testability."""
return min(current * 2, ceiling)
def _ws_token_ok(websocket) -> bool:
"""Return True if the handshake carries a valid shared token.
When WS_AUTH_TOKEN is empty the server is in open mode and accepts any
client (useful for local dev). When set, the client must connect to
``ws://host:port/?token=<value>`` and the comparison is constant-time.
When ``settings.ws_auth_token`` is empty the server is in open mode and
accepts any client (useful for local dev). When set, the client must
connect to ``ws://host:port/?token=<value>`` and the comparison is
constant-time.
"""
if not WS_AUTH_TOKEN:
if not settings.ws_auth_token:
return True
# websockets >= 11 exposes the full request-target (path + query) via
# connection.request.path; older versions kept it on connection.path.
request = getattr(websocket, "request", None)
target = getattr(request, "path", None) or getattr(websocket, "path", "")
token = (parse_qs(urlsplit(target).query).get("token") or [""])[0]
return bool(token) and secrets.compare_digest(token, WS_AUTH_TOKEN)
return bool(token) and secrets.compare_digest(token, settings.ws_auth_token)
async def ws_handler(websocket):
@ -156,13 +171,13 @@ async def ws_handler(websocket):
async def post_slack(text: str):
"""Send a notification to Slack using the webhook."""
if not SLACK_WEBHOOK_URL:
if not settings.slack_webhook_url:
log.debug("Slack webhook URL not configured; skipping")
return
try:
async with aiohttp.ClientSession() as session:
resp = await session.post(
SLACK_WEBHOOK_URL, json={"text": text}, timeout=10
settings.slack_webhook_url, json={"text": text}, timeout=10
)
body = await resp.text()
log.info("Slack post status=%s", resp.status)
@ -199,13 +214,7 @@ async def _broadcast(msg: str):
async def mqtt_loop():
"""Maintain an MQTT subscription, reconnecting with exponential backoff.
Previously a single connect() failure (or any later disconnect) left
the loop idle in an "asyncio.sleep(1)" forever. Now the loop wraps
connect + idle-while-connected in retry with capped exponential
backoff so a broker hiccup self-heals.
"""
"""Maintain an MQTT subscription, reconnecting with exponential backoff."""
client = MQTTClient("backend-bridge")
disconnected = asyncio.Event()
@ -213,8 +222,8 @@ async def mqtt_loop():
def on_connect(c, flags, rc, properties):
log.info("MQTT connected rc=%s flags=%s props=%s", rc, flags, properties)
try:
c.subscribe(TOPIC, qos=1)
log.info("MQTT subscribe requested topic=%s", TOPIC)
c.subscribe(settings.topic, qos=1)
log.info("MQTT subscribe requested topic=%s", settings.topic)
except Exception:
log.exception("MQTT subscribe failed")
@ -229,7 +238,7 @@ async def mqtt_loop():
"""Handle incoming MQTT messages from devices."""
try:
raw = payload.decode() if isinstance(payload, (bytes, bytearray)) else payload
if len(raw) > MAX_PAYLOAD_BYTES:
if len(raw) > settings.max_payload_bytes:
log.warning("MQTT payload too large topic=%s bytes=%d", topic, len(raw))
return
log.debug("MQTT raw topic=%s payload=%s", topic, raw[:100])
@ -243,9 +252,9 @@ async def mqtt_loop():
log.warning("MQTT invalid device_id topic=%s raw=%r",
topic, data.get("device_id"))
return
if device_id not in device_state and len(device_state) >= MAX_DEVICES:
if device_id not in device_state and len(device_state) >= settings.max_devices:
log.warning("MQTT device cap reached cap=%d dropping=%s",
MAX_DEVICES, device_id)
settings.max_devices, device_id)
return
device_state[device_id] = data
last_seen[device_id] = time.time()
@ -259,7 +268,8 @@ async def mqtt_loop():
# 🚨 Slack alert if CPU too high (throttled per device).
cpu = data.get("cpu_percent")
now = time.time()
if should_alert(cpu, device_id, now, _last_alert, ALERT_COOLDOWN, CPU_ALERT_TH):
if should_alert(cpu, device_id, now, _last_alert,
settings.alert_cooldown, settings.cpu_alert_th):
_last_alert[device_id] = now
text = (
f":rotating_light: CPU ALERT at {device_id}"
@ -276,20 +286,24 @@ async def mqtt_loop():
# TLS context is reusable across reconnects; build it once.
ssl_ctx = ssl.create_default_context()
ssl_ctx.load_verify_locations(os.getenv("CA_CERT"))
ssl_ctx.load_cert_chain(
certfile=os.getenv("CLIENT_CERT"),
keyfile=os.getenv("CLIENT_KEY"),
)
if settings.ca_cert:
ssl_ctx.load_verify_locations(settings.ca_cert)
if settings.client_cert and settings.client_key:
ssl_ctx.load_cert_chain(
certfile=settings.client_cert,
keyfile=settings.client_key,
)
backoff = MQTT_BACKOFF_MIN
backoff = settings.mqtt_backoff_min
while True:
try:
log.info("MQTT connecting host=%s port=%s", MQTT_BROKER, MQTT_PORT)
log.info("MQTT connecting host=%s port=%s",
settings.mqtt_broker, settings.mqtt_port)
disconnected.clear()
await client.connect(MQTT_BROKER, port=MQTT_PORT, ssl=ssl_ctx)
await client.connect(settings.mqtt_broker, port=settings.mqtt_port,
ssl=ssl_ctx)
log.info("MQTT connect() returned")
backoff = MQTT_BACKOFF_MIN # reset after a successful connect
backoff = settings.mqtt_backoff_min
await disconnected.wait()
except asyncio.CancelledError:
raise
@ -301,13 +315,13 @@ async def mqtt_loop():
await asyncio.sleep(backoff)
except asyncio.CancelledError:
raise
backoff = next_backoff(backoff, MQTT_BACKOFF_MAX)
backoff = next_backoff(backoff, settings.mqtt_backoff_max)
async def slack_summary_loop():
"""Periodically send a summary of active devices to Slack."""
while True:
await asyncio.sleep(SUMMARY_INTERVAL)
await asyncio.sleep(settings.summary_interval)
snapshot = active_snapshot()
if not snapshot:
log.debug("Slack no active devices to summarize")
@ -329,8 +343,9 @@ async def slack_summary_loop():
async def main():
"""Start WebSocket server, MQTT loop, and Slack summary loop."""
async with serve(ws_handler, WS_HOST, WS_PORT):
log.info("WS server listening url=ws://%s:%s", WS_HOST, WS_PORT)
async with serve(ws_handler, settings.ws_host, settings.ws_port):
log.info("WS server listening url=ws://%s:%s",
settings.ws_host, settings.ws_port)
await asyncio.gather(
mqtt_loop(),
slack_summary_loop(),

View file

@ -1,36 +1,47 @@
import asyncio
import json
import logging
import os
import ssl
import subprocess
import time
from typing import Optional
import psutil
from gmqtt import Client as MQTTClient
from dotenv import load_dotenv
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
# Load environment variables
load_dotenv()
class AgentSettings(BaseSettings):
"""Typed configuration for the edge-device agent."""
model_config = SettingsConfigDict(
env_file=".env", env_file_encoding="utf-8", extra="ignore"
)
device_id: str = "device-iot-001"
mqtt_broker: str = ""
mqtt_port: int = Field(default=8883, ge=1, le=65535)
ca_cert: str = "AmazonRootCA1.pem"
client_cert: str = "device.cert.pem"
client_key: str = "device.private.key"
interval: int = Field(default=10, ge=1)
agent_cpu_warn: float = Field(default=90.0, ge=0, le=100)
log_level: str = "INFO"
settings = AgentSettings()
logging.basicConfig(
level=os.getenv("LOG_LEVEL", "INFO").upper(),
level=settings.log_level.upper(),
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
log = logging.getLogger("edgewatch.agent")
DEVICE_ID = os.getenv("DEVICE_ID", "device-iot-001")
MQTT_BROKER = os.getenv("MQTT_BROKER") # AWS IoT endpoint (xxx-ats.iot.<region>.amazonaws.com)
MQTT_PORT = int(os.getenv("MQTT_PORT", "8883"))
CA_CERT = os.getenv("CA_CERT", "AmazonRootCA1.pem")
CLIENT_CERT = os.getenv("CLIENT_CERT", "device.cert.pem")
CLIENT_KEY = os.getenv("CLIENT_KEY", "device.private.key")
INTERVAL = int(os.getenv("INTERVAL", "10"))
# Local warn threshold for stdout logs (Slack alerting lives in the backend).
AGENT_CPU_WARN = float(os.getenv("AGENT_CPU_WARN", "90"))
TOPIC = f"devices/{DEVICE_ID}/metrics"
TOPIC = f"devices/{settings.device_id}/metrics"
# Process reference to measure agent overhead
process = psutil.Process(os.getpid())
process = psutil.Process()
def collect_metrics():
@ -48,7 +59,7 @@ def collect_metrics():
]
).decode().strip()
gpu_util, _, _ = out.split(",")
gpu = float(gpu_util)
gpu: Optional[float] = float(gpu_util)
except Exception: # noqa: BLE001
gpu = None # No GPU available
@ -56,7 +67,7 @@ def collect_metrics():
agent_cpu = process.cpu_percent(interval=None) # %
return {
"device_id": DEVICE_ID,
"device_id": settings.device_id,
"timestamp": int(time.time()),
"cpu_percent": cpu,
"mem_percent": mem,
@ -69,43 +80,41 @@ def collect_metrics():
async def main():
"""Main loop: connect to MQTT broker and periodically publish metrics."""
client = MQTTClient(DEVICE_ID)
client = MQTTClient(settings.device_id)
# Event handlers (unused args prefixed with _ to satisfy linting)
def on_connect(_client, _flags, _rc, _properties):
log.info("agent connected device=%s", DEVICE_ID)
log.info("agent connected device=%s", settings.device_id)
def on_disconnect(_client, _packet, _exc=None):
log.warning("agent disconnected device=%s", DEVICE_ID)
log.warning("agent disconnected device=%s", settings.device_id)
client.on_connect = on_connect
client.on_disconnect = on_disconnect
# TLS configuration
ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
ssl_ctx.load_verify_locations(CA_CERT)
ssl_ctx.load_cert_chain(certfile=CLIENT_CERT, keyfile=CLIENT_KEY)
ssl_ctx.load_verify_locations(settings.ca_cert)
ssl_ctx.load_cert_chain(certfile=settings.client_cert,
keyfile=settings.client_key)
# Connect to AWS IoT
await client.connect(MQTT_BROKER, MQTT_PORT, ssl=ssl_ctx)
await client.connect(settings.mqtt_broker, settings.mqtt_port, ssl=ssl_ctx)
try:
while True:
metrics = collect_metrics()
if metrics["cpu_percent"] > AGENT_CPU_WARN:
if metrics["cpu_percent"] > settings.agent_cpu_warn:
log.warning("local CPU warn device=%s cpu=%.1f",
DEVICE_ID, metrics["cpu_percent"])
settings.device_id, metrics["cpu_percent"])
log.debug("agent self_cpu=%.1f self_mem_mb=%.2f",
metrics["agent_cpu_percent"], metrics["agent_mem_mb"])
client.publish(TOPIC, json.dumps(metrics), qos=1, retain=False)
await asyncio.sleep(INTERVAL)
await asyncio.sleep(settings.interval)
except asyncio.CancelledError:
log.info("agent cancelled device=%s", DEVICE_ID)
log.info("agent cancelled device=%s", settings.device_id)
finally:
await client.disconnect()

View file

@ -1,6 +1,7 @@
aiohappyeyeballs==2.6.1
aiohttp==3.13.4
aiosignal==1.4.0
annotated-types==0.7.0
attrs==25.3.0
frozenlist==1.7.0
gmqtt==0.7.0
@ -8,7 +9,11 @@ idna==3.10
multidict==6.6.4
propcache==0.3.2
psutil==7.0.0
pydantic==2.13.4
pydantic-core==2.46.4
pydantic-settings==2.14.1
python-dotenv==1.2.2
typing-inspection==0.4.2
typing_extensions==4.15.0
websockets==15.0.1
yarl==1.20.1

View file

@ -18,28 +18,28 @@ def _fake_ws(target: str):
def test_open_mode_accepts_any_handshake(monkeypatch):
# Empty WS_AUTH_TOKEN = open mode (dev-friendly default).
monkeypatch.setattr(backend, "WS_AUTH_TOKEN", "")
monkeypatch.setattr(backend.settings, "ws_auth_token", "")
assert backend._ws_token_ok(_fake_ws("/")) is True
assert backend._ws_token_ok(_fake_ws("/?token=anything")) is True
def test_rejects_missing_token(monkeypatch):
monkeypatch.setattr(backend, "WS_AUTH_TOKEN", "s3cret")
monkeypatch.setattr(backend.settings, "ws_auth_token", "s3cret")
assert backend._ws_token_ok(_fake_ws("/")) is False
def test_rejects_wrong_token(monkeypatch):
monkeypatch.setattr(backend, "WS_AUTH_TOKEN", "s3cret")
monkeypatch.setattr(backend.settings, "ws_auth_token", "s3cret")
assert backend._ws_token_ok(_fake_ws("/?token=wrong")) is False
def test_accepts_correct_token(monkeypatch):
monkeypatch.setattr(backend, "WS_AUTH_TOKEN", "s3cret")
monkeypatch.setattr(backend.settings, "ws_auth_token", "s3cret")
assert backend._ws_token_ok(_fake_ws("/?token=s3cret")) is True
def test_accepts_token_with_extra_query_params(monkeypatch):
monkeypatch.setattr(backend, "WS_AUTH_TOKEN", "s3cret")
monkeypatch.setattr(backend.settings, "ws_auth_token", "s3cret")
assert backend._ws_token_ok(_fake_ws("/?foo=bar&token=s3cret&x=1")) is True
@ -47,5 +47,5 @@ def test_rejects_empty_token_value(monkeypatch):
# ?token= with no value should never match a configured secret,
# otherwise compare_digest("", "") would let everyone in if
# WS_AUTH_TOKEN were ever accidentally set to "".
monkeypatch.setattr(backend, "WS_AUTH_TOKEN", "s3cret")
monkeypatch.setattr(backend.settings, "ws_auth_token", "s3cret")
assert backend._ws_token_ok(_fake_ws("/?token=")) is False