feat(config): typed settings via pydantic-settings
Replace the scattered os.getenv() + int()/float() pattern with a
BaseSettings class on both modules. Wins:
- bad config now fails at import with a readable pydantic error
(WS_PORT=abc no longer produces a ValueError stack from inside
main()); ports are bounded to [1, 65535], cpu_alert_th to [0,100],
backoff_min/interval to >= 1,
- .env loading moves into pydantic-settings (env_file in
SettingsConfigDict), so the manual load_dotenv() call is gone,
- every callback now reads from a single ``settings`` instance, so
runtime overrides are possible (tests use monkeypatch on
backend.settings instead of patching module-level constants).
Test for ws_token is updated to patch backend.settings.ws_auth_token
rather than the old WS_AUTH_TOKEN module constant; the contract is
unchanged so all 55 tests still pass.
Pydantic stack pinned: pydantic==2.13.4, pydantic-core==2.46.4,
pydantic-settings==2.14.1 (plus annotated-types and typing-inspection
as transitives). pip-audit remains clean.
This commit is contained in:
parent
06c665843d
commit
adf6a7a1ce
4 changed files with 139 additions and 110 deletions
161
backend.py
161
backend.py
|
|
@ -13,59 +13,68 @@ import re
|
|||
import secrets
|
||||
import ssl
|
||||
import time
|
||||
from typing import Optional
|
||||
from urllib.parse import parse_qs, urlsplit
|
||||
|
||||
import aiohttp
|
||||
from gmqtt import Client as MQTTClient
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
from websockets import serve # websockets >= 10
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables from .env
|
||||
load_dotenv()
|
||||
|
||||
# Structured logging. LOG_LEVEL accepts DEBUG / INFO / WARNING / ERROR.
|
||||
class BackendSettings(BaseSettings):
|
||||
"""Typed configuration. Values are read from environment variables or
|
||||
a ``.env`` file at startup and validated by pydantic. A bad value
|
||||
(``WS_PORT=abc``, ``CPU_ALERT_TH=200``) fails fast with a readable
|
||||
message instead of crashing deep inside a callback.
|
||||
"""
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env", env_file_encoding="utf-8", extra="ignore"
|
||||
)
|
||||
|
||||
# MQTT
|
||||
mqtt_broker: str = ""
|
||||
mqtt_port: int = Field(default=8883, ge=1, le=65535)
|
||||
topic: str = "devices/+/metrics"
|
||||
ca_cert: Optional[str] = None
|
||||
client_cert: Optional[str] = None
|
||||
client_key: Optional[str] = None
|
||||
mqtt_backoff_min: int = Field(default=1, ge=1)
|
||||
mqtt_backoff_max: int = Field(default=60, ge=1)
|
||||
|
||||
# WebSocket
|
||||
ws_host: str = "0.0.0.0"
|
||||
ws_port: int = Field(default=6789, ge=1, le=65535)
|
||||
# Shared token required as ?token=... on the WS handshake. Empty = dev.
|
||||
ws_auth_token: str = ""
|
||||
|
||||
# App
|
||||
prune_seconds: int = Field(default=30, ge=0)
|
||||
cpu_alert_th: float = Field(default=90.0, ge=0, le=100)
|
||||
slack_webhook_url: Optional[str] = None
|
||||
summary_interval: int = Field(default=60, ge=1)
|
||||
|
||||
# Hardening against untrusted MQTT input.
|
||||
max_payload_bytes: int = Field(default=16 * 1024, ge=64)
|
||||
max_devices: int = Field(default=1000, ge=1)
|
||||
alert_cooldown: int = Field(default=60, ge=0)
|
||||
|
||||
log_level: str = "INFO"
|
||||
|
||||
|
||||
settings = BackendSettings()
|
||||
|
||||
logging.basicConfig(
|
||||
level=os.getenv("LOG_LEVEL", "INFO").upper(),
|
||||
level=settings.log_level.upper(),
|
||||
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||
)
|
||||
log = logging.getLogger("edgewatch.backend")
|
||||
|
||||
# MQTT configs
|
||||
MQTT_BROKER = os.getenv("MQTT_BROKER")
|
||||
MQTT_PORT = int(os.getenv("MQTT_PORT", "8883"))
|
||||
TOPIC = os.getenv("TOPIC", "devices/+/metrics")
|
||||
|
||||
# WebSocket configs
|
||||
WS_HOST = os.getenv("WS_HOST", "0.0.0.0")
|
||||
WS_PORT = int(os.getenv("WS_PORT", "6789"))
|
||||
# Shared token required as ?token=... on the WS handshake. When empty the
|
||||
# server accepts any client (dev mode); set this in production.
|
||||
WS_AUTH_TOKEN = os.getenv("WS_AUTH_TOKEN", "")
|
||||
|
||||
# App configs
|
||||
PRUNE_SECONDS = int(os.getenv("PRUNE_SECONDS", "30"))
|
||||
CPU_ALERT_TH = float(os.getenv("CPU_ALERT_TH", "90"))
|
||||
SLACK_WEBHOOK_URL = os.getenv("SLACK_WEBHOOK_URL")
|
||||
SUMMARY_INTERVAL = int(os.getenv("SUMMARY_INTERVAL", "60"))
|
||||
|
||||
# Hard limits on untrusted MQTT input.
|
||||
MAX_PAYLOAD_BYTES = int(os.getenv("MAX_PAYLOAD_BYTES", str(16 * 1024)))
|
||||
MAX_DEVICES = int(os.getenv("MAX_DEVICES", "1000"))
|
||||
_DEVICE_ID_RE = re.compile(r"^[A-Za-z0-9._-]{1,64}$")
|
||||
|
||||
# Slack alert throttling: re-alert per device at most once per cooldown.
|
||||
ALERT_COOLDOWN = int(os.getenv("ALERT_COOLDOWN", "60"))
|
||||
_last_alert: dict[str, float] = {}
|
||||
|
||||
# MQTT reconnect backoff bounds (seconds).
|
||||
MQTT_BACKOFF_MIN = int(os.getenv("MQTT_BACKOFF_MIN", "1"))
|
||||
MQTT_BACKOFF_MAX = int(os.getenv("MQTT_BACKOFF_MAX", "60"))
|
||||
|
||||
|
||||
def next_backoff(current, ceiling):
|
||||
"""Double the backoff but never exceed ``ceiling``. Pure for testability."""
|
||||
return min(current * 2, ceiling)
|
||||
|
||||
clients = set()
|
||||
device_state = {}
|
||||
last_seen = {}
|
||||
|
|
@ -84,8 +93,8 @@ def filter_active(state, seen, now, prune_seconds):
|
|||
|
||||
|
||||
def active_snapshot():
|
||||
"""Return only devices that have updated within PRUNE_SECONDS."""
|
||||
return filter_active(device_state, last_seen, time.time(), PRUNE_SECONDS)
|
||||
"""Return only devices that have updated within ``settings.prune_seconds``."""
|
||||
return filter_active(device_state, last_seen, time.time(), settings.prune_seconds)
|
||||
|
||||
|
||||
def validate_device_id(raw):
|
||||
|
|
@ -114,21 +123,27 @@ def should_alert(cpu, device_id, now, last_alert_map, cooldown, threshold):
|
|||
)
|
||||
|
||||
|
||||
def next_backoff(current, ceiling):
|
||||
"""Double the backoff but never exceed ``ceiling``. Pure for testability."""
|
||||
return min(current * 2, ceiling)
|
||||
|
||||
|
||||
def _ws_token_ok(websocket) -> bool:
|
||||
"""Return True if the handshake carries a valid shared token.
|
||||
|
||||
When WS_AUTH_TOKEN is empty the server is in open mode and accepts any
|
||||
client (useful for local dev). When set, the client must connect to
|
||||
``ws://host:port/?token=<value>`` and the comparison is constant-time.
|
||||
When ``settings.ws_auth_token`` is empty the server is in open mode and
|
||||
accepts any client (useful for local dev). When set, the client must
|
||||
connect to ``ws://host:port/?token=<value>`` and the comparison is
|
||||
constant-time.
|
||||
"""
|
||||
if not WS_AUTH_TOKEN:
|
||||
if not settings.ws_auth_token:
|
||||
return True
|
||||
# websockets >= 11 exposes the full request-target (path + query) via
|
||||
# connection.request.path; older versions kept it on connection.path.
|
||||
request = getattr(websocket, "request", None)
|
||||
target = getattr(request, "path", None) or getattr(websocket, "path", "")
|
||||
token = (parse_qs(urlsplit(target).query).get("token") or [""])[0]
|
||||
return bool(token) and secrets.compare_digest(token, WS_AUTH_TOKEN)
|
||||
return bool(token) and secrets.compare_digest(token, settings.ws_auth_token)
|
||||
|
||||
|
||||
async def ws_handler(websocket):
|
||||
|
|
@ -156,13 +171,13 @@ async def ws_handler(websocket):
|
|||
|
||||
async def post_slack(text: str):
|
||||
"""Send a notification to Slack using the webhook."""
|
||||
if not SLACK_WEBHOOK_URL:
|
||||
if not settings.slack_webhook_url:
|
||||
log.debug("Slack webhook URL not configured; skipping")
|
||||
return
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
resp = await session.post(
|
||||
SLACK_WEBHOOK_URL, json={"text": text}, timeout=10
|
||||
settings.slack_webhook_url, json={"text": text}, timeout=10
|
||||
)
|
||||
body = await resp.text()
|
||||
log.info("Slack post status=%s", resp.status)
|
||||
|
|
@ -199,13 +214,7 @@ async def _broadcast(msg: str):
|
|||
|
||||
|
||||
async def mqtt_loop():
|
||||
"""Maintain an MQTT subscription, reconnecting with exponential backoff.
|
||||
|
||||
Previously a single connect() failure (or any later disconnect) left
|
||||
the loop idle in an "asyncio.sleep(1)" forever. Now the loop wraps
|
||||
connect + idle-while-connected in retry with capped exponential
|
||||
backoff so a broker hiccup self-heals.
|
||||
"""
|
||||
"""Maintain an MQTT subscription, reconnecting with exponential backoff."""
|
||||
client = MQTTClient("backend-bridge")
|
||||
disconnected = asyncio.Event()
|
||||
|
||||
|
|
@ -213,8 +222,8 @@ async def mqtt_loop():
|
|||
def on_connect(c, flags, rc, properties):
|
||||
log.info("MQTT connected rc=%s flags=%s props=%s", rc, flags, properties)
|
||||
try:
|
||||
c.subscribe(TOPIC, qos=1)
|
||||
log.info("MQTT subscribe requested topic=%s", TOPIC)
|
||||
c.subscribe(settings.topic, qos=1)
|
||||
log.info("MQTT subscribe requested topic=%s", settings.topic)
|
||||
except Exception:
|
||||
log.exception("MQTT subscribe failed")
|
||||
|
||||
|
|
@ -229,7 +238,7 @@ async def mqtt_loop():
|
|||
"""Handle incoming MQTT messages from devices."""
|
||||
try:
|
||||
raw = payload.decode() if isinstance(payload, (bytes, bytearray)) else payload
|
||||
if len(raw) > MAX_PAYLOAD_BYTES:
|
||||
if len(raw) > settings.max_payload_bytes:
|
||||
log.warning("MQTT payload too large topic=%s bytes=%d", topic, len(raw))
|
||||
return
|
||||
log.debug("MQTT raw topic=%s payload=%s", topic, raw[:100])
|
||||
|
|
@ -243,9 +252,9 @@ async def mqtt_loop():
|
|||
log.warning("MQTT invalid device_id topic=%s raw=%r",
|
||||
topic, data.get("device_id"))
|
||||
return
|
||||
if device_id not in device_state and len(device_state) >= MAX_DEVICES:
|
||||
if device_id not in device_state and len(device_state) >= settings.max_devices:
|
||||
log.warning("MQTT device cap reached cap=%d dropping=%s",
|
||||
MAX_DEVICES, device_id)
|
||||
settings.max_devices, device_id)
|
||||
return
|
||||
device_state[device_id] = data
|
||||
last_seen[device_id] = time.time()
|
||||
|
|
@ -259,7 +268,8 @@ async def mqtt_loop():
|
|||
# 🚨 Slack alert if CPU too high (throttled per device).
|
||||
cpu = data.get("cpu_percent")
|
||||
now = time.time()
|
||||
if should_alert(cpu, device_id, now, _last_alert, ALERT_COOLDOWN, CPU_ALERT_TH):
|
||||
if should_alert(cpu, device_id, now, _last_alert,
|
||||
settings.alert_cooldown, settings.cpu_alert_th):
|
||||
_last_alert[device_id] = now
|
||||
text = (
|
||||
f":rotating_light: CPU ALERT at {device_id} — "
|
||||
|
|
@ -276,20 +286,24 @@ async def mqtt_loop():
|
|||
|
||||
# TLS context is reusable across reconnects; build it once.
|
||||
ssl_ctx = ssl.create_default_context()
|
||||
ssl_ctx.load_verify_locations(os.getenv("CA_CERT"))
|
||||
ssl_ctx.load_cert_chain(
|
||||
certfile=os.getenv("CLIENT_CERT"),
|
||||
keyfile=os.getenv("CLIENT_KEY"),
|
||||
)
|
||||
if settings.ca_cert:
|
||||
ssl_ctx.load_verify_locations(settings.ca_cert)
|
||||
if settings.client_cert and settings.client_key:
|
||||
ssl_ctx.load_cert_chain(
|
||||
certfile=settings.client_cert,
|
||||
keyfile=settings.client_key,
|
||||
)
|
||||
|
||||
backoff = MQTT_BACKOFF_MIN
|
||||
backoff = settings.mqtt_backoff_min
|
||||
while True:
|
||||
try:
|
||||
log.info("MQTT connecting host=%s port=%s", MQTT_BROKER, MQTT_PORT)
|
||||
log.info("MQTT connecting host=%s port=%s",
|
||||
settings.mqtt_broker, settings.mqtt_port)
|
||||
disconnected.clear()
|
||||
await client.connect(MQTT_BROKER, port=MQTT_PORT, ssl=ssl_ctx)
|
||||
await client.connect(settings.mqtt_broker, port=settings.mqtt_port,
|
||||
ssl=ssl_ctx)
|
||||
log.info("MQTT connect() returned")
|
||||
backoff = MQTT_BACKOFF_MIN # reset after a successful connect
|
||||
backoff = settings.mqtt_backoff_min
|
||||
await disconnected.wait()
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
|
|
@ -301,13 +315,13 @@ async def mqtt_loop():
|
|||
await asyncio.sleep(backoff)
|
||||
except asyncio.CancelledError:
|
||||
raise
|
||||
backoff = next_backoff(backoff, MQTT_BACKOFF_MAX)
|
||||
backoff = next_backoff(backoff, settings.mqtt_backoff_max)
|
||||
|
||||
|
||||
async def slack_summary_loop():
|
||||
"""Periodically send a summary of active devices to Slack."""
|
||||
while True:
|
||||
await asyncio.sleep(SUMMARY_INTERVAL)
|
||||
await asyncio.sleep(settings.summary_interval)
|
||||
snapshot = active_snapshot()
|
||||
if not snapshot:
|
||||
log.debug("Slack no active devices to summarize")
|
||||
|
|
@ -329,8 +343,9 @@ async def slack_summary_loop():
|
|||
|
||||
async def main():
|
||||
"""Start WebSocket server, MQTT loop, and Slack summary loop."""
|
||||
async with serve(ws_handler, WS_HOST, WS_PORT):
|
||||
log.info("WS server listening url=ws://%s:%s", WS_HOST, WS_PORT)
|
||||
async with serve(ws_handler, settings.ws_host, settings.ws_port):
|
||||
log.info("WS server listening url=ws://%s:%s",
|
||||
settings.ws_host, settings.ws_port)
|
||||
await asyncio.gather(
|
||||
mqtt_loop(),
|
||||
slack_summary_loop(),
|
||||
|
|
|
|||
|
|
@ -1,36 +1,47 @@
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import ssl
|
||||
import subprocess
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import psutil
|
||||
from gmqtt import Client as MQTTClient
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
|
||||
class AgentSettings(BaseSettings):
|
||||
"""Typed configuration for the edge-device agent."""
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env", env_file_encoding="utf-8", extra="ignore"
|
||||
)
|
||||
|
||||
device_id: str = "device-iot-001"
|
||||
mqtt_broker: str = ""
|
||||
mqtt_port: int = Field(default=8883, ge=1, le=65535)
|
||||
ca_cert: str = "AmazonRootCA1.pem"
|
||||
client_cert: str = "device.cert.pem"
|
||||
client_key: str = "device.private.key"
|
||||
interval: int = Field(default=10, ge=1)
|
||||
agent_cpu_warn: float = Field(default=90.0, ge=0, le=100)
|
||||
log_level: str = "INFO"
|
||||
|
||||
|
||||
settings = AgentSettings()
|
||||
|
||||
logging.basicConfig(
|
||||
level=os.getenv("LOG_LEVEL", "INFO").upper(),
|
||||
level=settings.log_level.upper(),
|
||||
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||
)
|
||||
log = logging.getLogger("edgewatch.agent")
|
||||
|
||||
DEVICE_ID = os.getenv("DEVICE_ID", "device-iot-001")
|
||||
MQTT_BROKER = os.getenv("MQTT_BROKER") # AWS IoT endpoint (xxx-ats.iot.<region>.amazonaws.com)
|
||||
MQTT_PORT = int(os.getenv("MQTT_PORT", "8883"))
|
||||
CA_CERT = os.getenv("CA_CERT", "AmazonRootCA1.pem")
|
||||
CLIENT_CERT = os.getenv("CLIENT_CERT", "device.cert.pem")
|
||||
CLIENT_KEY = os.getenv("CLIENT_KEY", "device.private.key")
|
||||
INTERVAL = int(os.getenv("INTERVAL", "10"))
|
||||
# Local warn threshold for stdout logs (Slack alerting lives in the backend).
|
||||
AGENT_CPU_WARN = float(os.getenv("AGENT_CPU_WARN", "90"))
|
||||
TOPIC = f"devices/{DEVICE_ID}/metrics"
|
||||
TOPIC = f"devices/{settings.device_id}/metrics"
|
||||
|
||||
# Process reference to measure agent overhead
|
||||
process = psutil.Process(os.getpid())
|
||||
process = psutil.Process()
|
||||
|
||||
|
||||
def collect_metrics():
|
||||
|
|
@ -48,7 +59,7 @@ def collect_metrics():
|
|||
]
|
||||
).decode().strip()
|
||||
gpu_util, _, _ = out.split(",")
|
||||
gpu = float(gpu_util)
|
||||
gpu: Optional[float] = float(gpu_util)
|
||||
except Exception: # noqa: BLE001
|
||||
gpu = None # No GPU available
|
||||
|
||||
|
|
@ -56,7 +67,7 @@ def collect_metrics():
|
|||
agent_cpu = process.cpu_percent(interval=None) # %
|
||||
|
||||
return {
|
||||
"device_id": DEVICE_ID,
|
||||
"device_id": settings.device_id,
|
||||
"timestamp": int(time.time()),
|
||||
"cpu_percent": cpu,
|
||||
"mem_percent": mem,
|
||||
|
|
@ -69,43 +80,41 @@ def collect_metrics():
|
|||
|
||||
async def main():
|
||||
"""Main loop: connect to MQTT broker and periodically publish metrics."""
|
||||
client = MQTTClient(DEVICE_ID)
|
||||
client = MQTTClient(settings.device_id)
|
||||
|
||||
# Event handlers (unused args prefixed with _ to satisfy linting)
|
||||
def on_connect(_client, _flags, _rc, _properties):
|
||||
log.info("agent connected device=%s", DEVICE_ID)
|
||||
log.info("agent connected device=%s", settings.device_id)
|
||||
|
||||
def on_disconnect(_client, _packet, _exc=None):
|
||||
log.warning("agent disconnected device=%s", DEVICE_ID)
|
||||
log.warning("agent disconnected device=%s", settings.device_id)
|
||||
|
||||
client.on_connect = on_connect
|
||||
client.on_disconnect = on_disconnect
|
||||
|
||||
# TLS configuration
|
||||
ssl_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
ssl_ctx.load_verify_locations(CA_CERT)
|
||||
ssl_ctx.load_cert_chain(certfile=CLIENT_CERT, keyfile=CLIENT_KEY)
|
||||
ssl_ctx.load_verify_locations(settings.ca_cert)
|
||||
ssl_ctx.load_cert_chain(certfile=settings.client_cert,
|
||||
keyfile=settings.client_key)
|
||||
|
||||
# Connect to AWS IoT
|
||||
await client.connect(MQTT_BROKER, MQTT_PORT, ssl=ssl_ctx)
|
||||
await client.connect(settings.mqtt_broker, settings.mqtt_port, ssl=ssl_ctx)
|
||||
|
||||
try:
|
||||
while True:
|
||||
metrics = collect_metrics()
|
||||
|
||||
if metrics["cpu_percent"] > AGENT_CPU_WARN:
|
||||
if metrics["cpu_percent"] > settings.agent_cpu_warn:
|
||||
log.warning("local CPU warn device=%s cpu=%.1f",
|
||||
DEVICE_ID, metrics["cpu_percent"])
|
||||
settings.device_id, metrics["cpu_percent"])
|
||||
|
||||
log.debug("agent self_cpu=%.1f self_mem_mb=%.2f",
|
||||
metrics["agent_cpu_percent"], metrics["agent_mem_mb"])
|
||||
|
||||
client.publish(TOPIC, json.dumps(metrics), qos=1, retain=False)
|
||||
|
||||
await asyncio.sleep(INTERVAL)
|
||||
await asyncio.sleep(settings.interval)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
log.info("agent cancelled device=%s", DEVICE_ID)
|
||||
log.info("agent cancelled device=%s", settings.device_id)
|
||||
finally:
|
||||
await client.disconnect()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
aiohappyeyeballs==2.6.1
|
||||
aiohttp==3.13.4
|
||||
aiosignal==1.4.0
|
||||
annotated-types==0.7.0
|
||||
attrs==25.3.0
|
||||
frozenlist==1.7.0
|
||||
gmqtt==0.7.0
|
||||
|
|
@ -8,7 +9,11 @@ idna==3.10
|
|||
multidict==6.6.4
|
||||
propcache==0.3.2
|
||||
psutil==7.0.0
|
||||
pydantic==2.13.4
|
||||
pydantic-core==2.46.4
|
||||
pydantic-settings==2.14.1
|
||||
python-dotenv==1.2.2
|
||||
typing-inspection==0.4.2
|
||||
typing_extensions==4.15.0
|
||||
websockets==15.0.1
|
||||
yarl==1.20.1
|
||||
|
|
|
|||
|
|
@ -18,28 +18,28 @@ def _fake_ws(target: str):
|
|||
|
||||
def test_open_mode_accepts_any_handshake(monkeypatch):
|
||||
# Empty WS_AUTH_TOKEN = open mode (dev-friendly default).
|
||||
monkeypatch.setattr(backend, "WS_AUTH_TOKEN", "")
|
||||
monkeypatch.setattr(backend.settings, "ws_auth_token", "")
|
||||
assert backend._ws_token_ok(_fake_ws("/")) is True
|
||||
assert backend._ws_token_ok(_fake_ws("/?token=anything")) is True
|
||||
|
||||
|
||||
def test_rejects_missing_token(monkeypatch):
|
||||
monkeypatch.setattr(backend, "WS_AUTH_TOKEN", "s3cret")
|
||||
monkeypatch.setattr(backend.settings, "ws_auth_token", "s3cret")
|
||||
assert backend._ws_token_ok(_fake_ws("/")) is False
|
||||
|
||||
|
||||
def test_rejects_wrong_token(monkeypatch):
|
||||
monkeypatch.setattr(backend, "WS_AUTH_TOKEN", "s3cret")
|
||||
monkeypatch.setattr(backend.settings, "ws_auth_token", "s3cret")
|
||||
assert backend._ws_token_ok(_fake_ws("/?token=wrong")) is False
|
||||
|
||||
|
||||
def test_accepts_correct_token(monkeypatch):
|
||||
monkeypatch.setattr(backend, "WS_AUTH_TOKEN", "s3cret")
|
||||
monkeypatch.setattr(backend.settings, "ws_auth_token", "s3cret")
|
||||
assert backend._ws_token_ok(_fake_ws("/?token=s3cret")) is True
|
||||
|
||||
|
||||
def test_accepts_token_with_extra_query_params(monkeypatch):
|
||||
monkeypatch.setattr(backend, "WS_AUTH_TOKEN", "s3cret")
|
||||
monkeypatch.setattr(backend.settings, "ws_auth_token", "s3cret")
|
||||
assert backend._ws_token_ok(_fake_ws("/?foo=bar&token=s3cret&x=1")) is True
|
||||
|
||||
|
||||
|
|
@ -47,5 +47,5 @@ def test_rejects_empty_token_value(monkeypatch):
|
|||
# ?token= with no value should never match a configured secret,
|
||||
# otherwise compare_digest("", "") would let everyone in if
|
||||
# WS_AUTH_TOKEN were ever accidentally set to "".
|
||||
monkeypatch.setattr(backend, "WS_AUTH_TOKEN", "s3cret")
|
||||
monkeypatch.setattr(backend.settings, "ws_auth_token", "s3cret")
|
||||
assert backend._ws_token_ok(_fake_ws("/?token=")) is False
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue