chore(stability): harden runtime, observability, and test coverage
Pacote de endurecimento focado em rodar 24/7 no NAS sem surpresas. DB / persistence - Register a datetime adapter so Python 3.12+ stops warning about the deprecated default; emits "YYYY-MM-DD HH:MM:SS[.ffffff]" to preserve existing string-based comparisons in get_enhanced_stats. - PRAGMA journal_mode=WAL on init (concurrency + crash safety on the NAS) and synchronous=NORMAL per connection (cheap, WAL-safe). - Two composite indexes: (author, subreddit) and (media_type, downloaded_at) — speed up the gallery and cleanup queries. API surface - New /health router with no auth dep, returns booleans for db / ffmpeg / downloads_writable / scheduler plus the package version and whether Basic Auth is configured. Reachable even when RMC_AUTH_USER/PASS are set, so it can be wired into Container Manager / external monitors. - Refactor: move require_auth() from app-level dependency to per-router include_router(..., dependencies=[require_auth()]); /health is the only route that opts out. - Rate limit (60/min per IP+path) on every POST/PUT/DELETE in config: subreddits, users, blacklist.*, settings.* — protects config.yaml from runaway loops or spam. Observability - Replace silent `except Exception` in main and scheduler with exc_info=True logging and a new collector_status["last_error"] field surfaced via /api/collector/status, so a failed 3am scheduled run isn't invisible. - RotatingFileHandler (10 MB × 5 backups) caps collector.log around 60 MB on the NAS volume. Quality - tests/test_router_scheduler.py covers status / config (interval + cron) / history / run-now / collector run / individual collect validation. scheduler.py coverage 25% → 55%, total 55% → 58% (135 → 151 tests). - Type extractors.gfycat (cast yt-dlp Any to str), removed from the mypy ignore list. First module destravado — pattern for the rest. Verified locally: ruff + mypy + 151 pytest green; docker container boots healthy in 6s; PRAGMA journal_mode reports `wal`; 70 POSTs to /api/subreddits returned 60 OK + 10 HTTP 429.
This commit is contained in:
parent
6e7aba37db
commit
a2727a9ac4
11 changed files with 362 additions and 35 deletions
|
|
@ -82,7 +82,6 @@ module = [
|
|||
"src.reddit_client",
|
||||
"src.database",
|
||||
"src.extractors.reddit",
|
||||
"src.extractors.gfycat",
|
||||
]
|
||||
ignore_errors = true
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,17 @@
|
|||
"""Configuration loader and validator."""
|
||||
|
||||
import logging
|
||||
import logging.handlers
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
|
||||
# 10 MB x 5 backups — caps collector.log at ~60 MB on the NAS volume.
|
||||
_LOG_MAX_BYTES = 10 * 1024 * 1024
|
||||
_LOG_BACKUP_COUNT = 5
|
||||
|
||||
|
||||
@dataclass
|
||||
class SubredditTarget:
|
||||
|
|
@ -151,7 +156,12 @@ def setup_logging(config: LoggingConfig) -> logging.Logger:
|
|||
logger.addHandler(console_handler)
|
||||
|
||||
if config.file:
|
||||
file_handler = logging.FileHandler(config.file, encoding="utf-8")
|
||||
file_handler = logging.handlers.RotatingFileHandler(
|
||||
config.file,
|
||||
maxBytes=_LOG_MAX_BYTES,
|
||||
backupCount=_LOG_BACKUP_COUNT,
|
||||
encoding="utf-8",
|
||||
)
|
||||
file_handler.setFormatter(formatter)
|
||||
logger.addHandler(file_handler)
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,12 @@ from pathlib import Path
|
|||
DEFAULT_DB_PATH = os.environ.get("RMC_DB_PATH", "media.db")
|
||||
|
||||
|
||||
# Python 3.12 deprecated sqlite3's default datetime adapter; register our own
|
||||
# emitting "YYYY-MM-DD HH:MM:SS[.ffffff]" so existing string-based comparisons
|
||||
# in get_enhanced_stats() keep working unchanged.
|
||||
sqlite3.register_adapter(datetime, lambda d: d.isoformat(sep=" "))
|
||||
|
||||
|
||||
def _resolve_db_path() -> str:
|
||||
"""Read RMC_DB_PATH lazily so tests can monkeypatch the env per case."""
|
||||
return os.environ.get("RMC_DB_PATH", "media.db")
|
||||
|
|
@ -81,6 +87,11 @@ class Database:
|
|||
conn.execute("CREATE INDEX IF NOT EXISTS idx_subreddit ON posts(subreddit)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_file_hash ON posts(file_hash)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_downloaded ON posts(downloaded_at)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_author_subreddit ON posts(author, subreddit)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_media_downloaded ON posts(media_type, downloaded_at)")
|
||||
# WAL is persisted on the DB header — set once is enough, but the
|
||||
# PRAGMA is idempotent so running it on every _init_db is fine.
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
# Scheduler history table
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS scheduler_history (
|
||||
|
|
@ -100,6 +111,8 @@ class Database:
|
|||
"""Context manager for database connections."""
|
||||
conn = sqlite3.connect(self.db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
# NORMAL is safe under WAL and avoids the fsync-per-commit penalty.
|
||||
conn.execute("PRAGMA synchronous=NORMAL")
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
|
|
|
|||
|
|
@ -22,13 +22,13 @@ def extract_gfycat_url(url: str) -> str | None:
|
|||
info = ydl.extract_info(url, download=False)
|
||||
|
||||
if info and "url" in info:
|
||||
return info["url"]
|
||||
return str(info["url"])
|
||||
|
||||
if info and "formats" in info:
|
||||
mp4_formats = [f for f in info["formats"] if f.get("ext") == "mp4" and f.get("url")]
|
||||
if mp4_formats:
|
||||
best = max(mp4_formats, key=lambda x: x.get("height", 0))
|
||||
return best["url"]
|
||||
return str(best["url"])
|
||||
|
||||
return None
|
||||
|
||||
|
|
|
|||
|
|
@ -223,8 +223,8 @@ def collect(config: Config, logger) -> CollectionStats:
|
|||
try:
|
||||
for post in client.get_subreddit_posts(target):
|
||||
process_post(post, client, db, downloader, config, stats, logger, source_type="subreddit")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing r/{target.name}: {e}")
|
||||
except Exception:
|
||||
logger.error("Error processing r/%s", target.name, exc_info=True)
|
||||
stats.errors += 1
|
||||
|
||||
for target in config.targets.users:
|
||||
|
|
@ -232,8 +232,8 @@ def collect(config: Config, logger) -> CollectionStats:
|
|||
try:
|
||||
for post in client.get_user_posts(target):
|
||||
process_post(post, client, db, downloader, config, stats, logger, source_type="user")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing u/{target.name}: {e}")
|
||||
except Exception:
|
||||
logger.error("Error processing u/%s", target.name, exc_info=True)
|
||||
stats.errors += 1
|
||||
|
||||
return stats
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""FastAPI web application for managing Reddit Media Collector."""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
|
|
@ -10,7 +11,7 @@ from fastapi.templating import Jinja2Templates
|
|||
|
||||
from . import config_manager
|
||||
from .auth import require_auth
|
||||
from .routers import config, favorites, media, scheduler, stats
|
||||
from .routers import config, favorites, health, media, scheduler, stats
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
|
|
@ -24,19 +25,29 @@ async def lifespan(app: FastAPI):
|
|||
scheduler.scheduler.shutdown(wait=False)
|
||||
|
||||
|
||||
try:
|
||||
_app_version = version("reddit-media-collector")
|
||||
except PackageNotFoundError:
|
||||
_app_version = "0.0.0+unknown"
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title="Reddit Media Collector",
|
||||
version="1.0.0",
|
||||
version=_app_version,
|
||||
lifespan=lifespan,
|
||||
dependencies=[require_auth()],
|
||||
)
|
||||
|
||||
# Include routers
|
||||
app.include_router(config.router)
|
||||
app.include_router(media.router)
|
||||
app.include_router(stats.router)
|
||||
app.include_router(scheduler.router)
|
||||
app.include_router(favorites.router)
|
||||
# /health is the only route exposed without auth — used by container
|
||||
# healthcheck, dashboards, and external monitors. Everything else is
|
||||
# protected when RMC_AUTH_USER/RMC_AUTH_PASS are set.
|
||||
app.include_router(health.router)
|
||||
|
||||
_auth_dep = [require_auth()]
|
||||
app.include_router(config.router, dependencies=_auth_dep)
|
||||
app.include_router(media.router, dependencies=_auth_dep)
|
||||
app.include_router(stats.router, dependencies=_auth_dep)
|
||||
app.include_router(scheduler.router, dependencies=_auth_dep)
|
||||
app.include_router(favorites.router, dependencies=_auth_dep)
|
||||
|
||||
_static_dir = Path(__file__).parent / "static"
|
||||
app.mount("/static", StaticFiles(directory=_static_dir), name="static")
|
||||
|
|
@ -44,7 +55,7 @@ app.mount("/static", StaticFiles(directory=_static_dir), name="static")
|
|||
templates = Jinja2Templates(directory=Path(__file__).parent / "templates")
|
||||
|
||||
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
@app.get("/", response_class=HTMLResponse, dependencies=[require_auth()])
|
||||
async def index(request: Request):
|
||||
"""Render main page."""
|
||||
subreddits = config_manager.get_subreddits()
|
||||
|
|
|
|||
|
|
@ -25,4 +25,11 @@ SCHEDULER_DB_PATH = Path(os.environ.get("RMC_SCHEDULER_DB", str(PROJECT_DIR / "s
|
|||
|
||||
# Collector state (protected by lock for thread safety)
|
||||
collector_lock = threading.Lock()
|
||||
collector_status: dict = {"running": False, "last_run": None, "last_result": None}
|
||||
collector_status: dict = {
|
||||
"running": False,
|
||||
"last_run": None,
|
||||
"last_result": None,
|
||||
# Populated when a scheduled run fails; mirrored to the UI so silent 3am
|
||||
# failures are visible without grepping the log.
|
||||
"last_error": None,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,12 +1,18 @@
|
|||
"""Configuration and blacklist API routes."""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .. import config_manager
|
||||
from ..rate_limit import rate_limit
|
||||
|
||||
router = APIRouter(tags=["config"])
|
||||
|
||||
# 60 mutations/minute per (IP, path) — protects config.yaml from runaway
|
||||
# loops or naive scripted spam. Shared dep instance; rate_limit's bucket key
|
||||
# already includes request.url.path, so each route gets its own counter.
|
||||
_mut = Depends(rate_limit(60, 60))
|
||||
|
||||
|
||||
class SubredditCreate(BaseModel):
|
||||
name: str
|
||||
|
|
@ -35,7 +41,7 @@ async def list_subreddits():
|
|||
return config_manager.get_subreddits()
|
||||
|
||||
|
||||
@router.post("/api/subreddits")
|
||||
@router.post("/api/subreddits", dependencies=[_mut])
|
||||
async def add_subreddit(data: SubredditCreate):
|
||||
"""Add a new subreddit."""
|
||||
if not data.name:
|
||||
|
|
@ -48,7 +54,7 @@ async def add_subreddit(data: SubredditCreate):
|
|||
return {"message": f"Subreddit '{data.name}' added successfully"}
|
||||
|
||||
|
||||
@router.delete("/api/subreddits/{name}")
|
||||
@router.delete("/api/subreddits/{name}", dependencies=[_mut])
|
||||
async def delete_subreddit(name: str):
|
||||
"""Remove a subreddit."""
|
||||
success = config_manager.remove_subreddit(name)
|
||||
|
|
@ -64,7 +70,7 @@ async def list_users():
|
|||
return config_manager.get_users()
|
||||
|
||||
|
||||
@router.post("/api/users")
|
||||
@router.post("/api/users", dependencies=[_mut])
|
||||
async def add_user(data: UserCreate):
|
||||
"""Add a new user."""
|
||||
if not data.name:
|
||||
|
|
@ -77,7 +83,7 @@ async def add_user(data: UserCreate):
|
|||
return {"message": f"User '{data.name}' added successfully"}
|
||||
|
||||
|
||||
@router.delete("/api/users/{name}")
|
||||
@router.delete("/api/users/{name}", dependencies=[_mut])
|
||||
async def delete_user(name: str):
|
||||
"""Remove a user."""
|
||||
success = config_manager.remove_user(name)
|
||||
|
|
@ -96,7 +102,7 @@ async def get_blacklist():
|
|||
return config_manager.get_blacklist()
|
||||
|
||||
|
||||
@router.post("/api/blacklist/authors")
|
||||
@router.post("/api/blacklist/authors", dependencies=[_mut])
|
||||
async def add_blacklist_author(data: BlacklistItem):
|
||||
"""Add an author to the blacklist."""
|
||||
if not data.value:
|
||||
|
|
@ -109,7 +115,7 @@ async def add_blacklist_author(data: BlacklistItem):
|
|||
return {"message": f"Author '{data.value}' added to blacklist"}
|
||||
|
||||
|
||||
@router.delete("/api/blacklist/authors/{author}")
|
||||
@router.delete("/api/blacklist/authors/{author}", dependencies=[_mut])
|
||||
async def remove_blacklist_author(author: str):
|
||||
"""Remove an author from the blacklist."""
|
||||
success = config_manager.remove_blacklist_author(author)
|
||||
|
|
@ -119,7 +125,7 @@ async def remove_blacklist_author(author: str):
|
|||
return {"message": f"Author '{author}' removed from blacklist"}
|
||||
|
||||
|
||||
@router.post("/api/blacklist/subreddits")
|
||||
@router.post("/api/blacklist/subreddits", dependencies=[_mut])
|
||||
async def add_blacklist_subreddit(data: BlacklistItem):
|
||||
"""Add a subreddit to the blacklist."""
|
||||
if not data.value:
|
||||
|
|
@ -132,7 +138,7 @@ async def add_blacklist_subreddit(data: BlacklistItem):
|
|||
return {"message": f"Subreddit '{data.value}' added to blacklist"}
|
||||
|
||||
|
||||
@router.delete("/api/blacklist/subreddits/{subreddit}")
|
||||
@router.delete("/api/blacklist/subreddits/{subreddit}", dependencies=[_mut])
|
||||
async def remove_blacklist_subreddit(subreddit: str):
|
||||
"""Remove a subreddit from the blacklist."""
|
||||
success = config_manager.remove_blacklist_subreddit(subreddit)
|
||||
|
|
@ -142,7 +148,7 @@ async def remove_blacklist_subreddit(subreddit: str):
|
|||
return {"message": f"Subreddit '{subreddit}' removed from blacklist"}
|
||||
|
||||
|
||||
@router.post("/api/blacklist/keywords")
|
||||
@router.post("/api/blacklist/keywords", dependencies=[_mut])
|
||||
async def add_blacklist_keyword(data: BlacklistItem):
|
||||
"""Add a title keyword to the blacklist."""
|
||||
if not data.value:
|
||||
|
|
@ -155,7 +161,7 @@ async def add_blacklist_keyword(data: BlacklistItem):
|
|||
return {"message": f"Keyword '{data.value}' added to blacklist"}
|
||||
|
||||
|
||||
@router.delete("/api/blacklist/keywords/{keyword:path}")
|
||||
@router.delete("/api/blacklist/keywords/{keyword:path}", dependencies=[_mut])
|
||||
async def remove_blacklist_keyword(keyword: str):
|
||||
"""Remove a title keyword from the blacklist."""
|
||||
success = config_manager.remove_blacklist_keyword(keyword)
|
||||
|
|
@ -165,7 +171,7 @@ async def remove_blacklist_keyword(keyword: str):
|
|||
return {"message": f"Keyword '{keyword}' removed from blacklist"}
|
||||
|
||||
|
||||
@router.post("/api/blacklist/domains")
|
||||
@router.post("/api/blacklist/domains", dependencies=[_mut])
|
||||
async def add_blacklist_domain(data: BlacklistItem):
|
||||
"""Add a domain to the blacklist."""
|
||||
if not data.value:
|
||||
|
|
@ -178,7 +184,7 @@ async def add_blacklist_domain(data: BlacklistItem):
|
|||
return {"message": f"Domain '{data.value}' added to blacklist"}
|
||||
|
||||
|
||||
@router.delete("/api/blacklist/domains/{domain:path}")
|
||||
@router.delete("/api/blacklist/domains/{domain:path}", dependencies=[_mut])
|
||||
async def remove_blacklist_domain(domain: str):
|
||||
"""Remove a domain from the blacklist."""
|
||||
success = config_manager.remove_blacklist_domain(domain)
|
||||
|
|
@ -210,7 +216,7 @@ async def get_settings():
|
|||
}
|
||||
|
||||
|
||||
@router.put("/api/settings/download")
|
||||
@router.put("/api/settings/download", dependencies=[_mut])
|
||||
async def update_download_settings(settings: DownloadSettings):
|
||||
"""Update download settings."""
|
||||
config = config_manager.load_config()
|
||||
|
|
@ -232,7 +238,7 @@ class RateLimitSettings(BaseModel):
|
|||
download_delay_seconds: float = 2.0
|
||||
|
||||
|
||||
@router.put("/api/settings/rate-limit")
|
||||
@router.put("/api/settings/rate-limit", dependencies=[_mut])
|
||||
async def update_rate_limit_settings(settings: RateLimitSettings):
|
||||
"""Update rate limit settings."""
|
||||
config = config_manager.load_config()
|
||||
|
|
|
|||
72
src/web/routers/health.py
Normal file
72
src/web/routers/health.py
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
"""Public health endpoint — always 200, body reports component status."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from contextlib import suppress
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from ...database import Database
|
||||
from ..deps import DOWNLOADS_DIR
|
||||
|
||||
router = APIRouter(tags=["health"])
|
||||
|
||||
|
||||
def _package_version() -> str:
|
||||
try:
|
||||
return version("reddit-media-collector")
|
||||
except PackageNotFoundError:
|
||||
return "unknown"
|
||||
|
||||
|
||||
def _db_ok() -> bool:
|
||||
try:
|
||||
with Database()._get_connection() as conn:
|
||||
conn.execute("SELECT 1").fetchone()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _ffmpeg_ok() -> bool:
|
||||
return shutil.which("ffmpeg") is not None
|
||||
|
||||
|
||||
def _downloads_writable() -> bool:
|
||||
probe = Path(DOWNLOADS_DIR) / ".healthcheck"
|
||||
try:
|
||||
DOWNLOADS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
probe.write_text("ok", encoding="utf-8")
|
||||
return True
|
||||
except OSError:
|
||||
return False
|
||||
finally:
|
||||
with suppress(OSError):
|
||||
probe.unlink(missing_ok=True)
|
||||
|
||||
|
||||
def _scheduler_ok() -> bool:
|
||||
# Imported lazily to dodge circular import (scheduler router imports deps).
|
||||
try:
|
||||
from . import scheduler
|
||||
|
||||
return bool(scheduler.scheduler.running)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health() -> dict:
|
||||
"""Report status of each subsystem. Always 200 — body tells the truth."""
|
||||
return {
|
||||
"version": _package_version(),
|
||||
"db": _db_ok(),
|
||||
"ffmpeg": _ffmpeg_ok(),
|
||||
"downloads_writable": _downloads_writable(),
|
||||
"scheduler": _scheduler_ok(),
|
||||
"auth_enabled": bool(os.environ.get("RMC_AUTH_USER") and os.environ.get("RMC_AUTH_PASS")),
|
||||
}
|
||||
|
|
@ -1,5 +1,6 @@
|
|||
"""Scheduler and collector API routes."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
|
|
@ -25,6 +26,8 @@ from ..deps import (
|
|||
collector_status,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(tags=["scheduler"])
|
||||
|
||||
# Scheduler setup
|
||||
|
|
@ -146,14 +149,18 @@ def _run_collector_scheduled():
|
|||
)
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error("Scheduled collector run %s timed out", run_id)
|
||||
with collector_lock:
|
||||
collector_status["last_result"] = "timeout"
|
||||
collector_status["last_error"] = "Execution timed out after 4 hours"
|
||||
db.finish_scheduler_run(
|
||||
run_id, "timeout", posts_processed, posts_downloaded, "Execution timed out after 4 hours"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Scheduled collector run %s failed", run_id, exc_info=True)
|
||||
with collector_lock:
|
||||
collector_status["last_result"] = f"error: {e!s}"
|
||||
collector_status["last_error"] = str(e)[:500]
|
||||
db.finish_scheduler_run(run_id, "error", posts_processed, posts_downloaded, str(e)[:500])
|
||||
finally:
|
||||
with collector_lock:
|
||||
|
|
@ -175,13 +182,24 @@ def _run_collector():
|
|||
timeout=14400,
|
||||
)
|
||||
with collector_lock:
|
||||
collector_status["last_result"] = "success" if result.returncode == 0 else "error"
|
||||
if result.returncode == 0:
|
||||
collector_status["last_result"] = "success"
|
||||
collector_status["last_error"] = None
|
||||
else:
|
||||
collector_status["last_result"] = "error"
|
||||
collector_status["last_error"] = (result.stderr or "")[:500] or "collector exited non-zero"
|
||||
if result.returncode != 0:
|
||||
logger.error("Manual collector run exited %s: %s", result.returncode, (result.stderr or "")[:500])
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error("Manual collector run timed out")
|
||||
with collector_lock:
|
||||
collector_status["last_result"] = "timeout"
|
||||
collector_status["last_error"] = "Execution timed out after 4 hours"
|
||||
except Exception as e:
|
||||
logger.error("Manual collector run failed", exc_info=True)
|
||||
with collector_lock:
|
||||
collector_status["last_result"] = f"error: {e!s}"
|
||||
collector_status["last_error"] = str(e)[:500]
|
||||
finally:
|
||||
with collector_lock:
|
||||
collector_status["running"] = False
|
||||
|
|
@ -250,12 +268,16 @@ def _run_individual_collection(target_type: str, target_name: str, media_types:
|
|||
os.unlink(tmp_config_path)
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
logger.error("Individual collection %s/%s timed out", target_type, target_name)
|
||||
with collector_lock:
|
||||
collector_status["last_result"] = "timeout"
|
||||
collector_status["last_error"] = f"{target_type}/{target_name}: execution timed out"
|
||||
db.finish_scheduler_run(run_id, "timeout", posts_processed, posts_downloaded, "Execution timed out")
|
||||
except Exception as e:
|
||||
logger.error("Individual collection %s/%s failed", target_type, target_name, exc_info=True)
|
||||
with collector_lock:
|
||||
collector_status["last_result"] = f"error: {e!s}"
|
||||
collector_status["last_error"] = f"{target_type}/{target_name}: {str(e)[:400]}"
|
||||
db.finish_scheduler_run(run_id, "error", posts_processed, posts_downloaded, str(e)[:500])
|
||||
finally:
|
||||
with collector_lock:
|
||||
|
|
|
|||
187
tests/test_router_scheduler.py
Normal file
187
tests/test_router_scheduler.py
Normal file
|
|
@ -0,0 +1,187 @@
|
|||
"""Tests for /api/scheduler and /api/collector routers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from src.database import Database
|
||||
from src.web import deps
|
||||
from src.web.app import app
|
||||
from src.web.routers import scheduler as scheduler_router
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_collector_status():
|
||||
"""Each test starts with a clean collector_status snapshot."""
|
||||
snapshot = dict(deps.collector_status)
|
||||
deps.collector_status.update({"running": False, "last_run": None, "last_result": None, "last_error": None})
|
||||
yield
|
||||
deps.collector_status.clear()
|
||||
deps.collector_status.update(snapshot)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def isolated_paths(tmp_path, monkeypatch):
|
||||
"""Point scheduler config + DB + RMC_DB_PATH at tmp_path."""
|
||||
cfg = tmp_path / "scheduler_config.yaml"
|
||||
db_path = tmp_path / "media.db"
|
||||
monkeypatch.setattr(scheduler_router, "SCHEDULER_CONFIG_PATH", cfg)
|
||||
monkeypatch.setenv("RMC_DB_PATH", str(db_path))
|
||||
# Pre-create the schema so /api/scheduler/history doesn't 500.
|
||||
Database(db_path=str(db_path))
|
||||
return SimpleNamespace(cfg=cfg, db_path=db_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_subprocess(monkeypatch):
|
||||
"""Replace subprocess.run in the scheduler module with a recorded fake."""
|
||||
calls = []
|
||||
|
||||
def fake_run(cmd, **kwargs):
|
||||
calls.append((cmd, kwargs))
|
||||
return SimpleNamespace(returncode=0, stdout="Posts processed: 3\nNew downloads: 1\n", stderr="")
|
||||
|
||||
monkeypatch.setattr(scheduler_router.subprocess, "run", fake_run)
|
||||
return calls
|
||||
|
||||
|
||||
class TestSchedulerStatus:
|
||||
def test_status_structure(self, client, isolated_paths):
|
||||
response = client.get("/api/scheduler/status")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert {"enabled", "interval_hours", "mode", "specific_times", "next_run", "is_running", "last_run"} <= set(
|
||||
data
|
||||
)
|
||||
|
||||
def test_collector_status_exposes_last_error(self, client):
|
||||
deps.collector_status["last_error"] = "kaboom: boom"
|
||||
response = client.get("/api/collector/status")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["last_error"] == "kaboom: boom"
|
||||
|
||||
|
||||
class TestSchedulerConfig:
|
||||
def test_put_interval_mode_persists(self, client, isolated_paths):
|
||||
payload = {"enabled": False, "interval_hours": 8, "mode": "interval"}
|
||||
response = client.put("/api/scheduler/config", json=payload)
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["config"]["interval_hours"] == 8
|
||||
assert body["config"]["mode"] == "interval"
|
||||
# YAML round-trip
|
||||
assert isolated_paths.cfg.exists()
|
||||
assert "interval_hours: 8" in isolated_paths.cfg.read_text()
|
||||
|
||||
# And the GET reflects it
|
||||
status = client.get("/api/scheduler/status").json()
|
||||
assert status["interval_hours"] == 8
|
||||
|
||||
def test_put_cron_mode_persists(self, client, isolated_paths):
|
||||
payload = {
|
||||
"enabled": False,
|
||||
"interval_hours": 6,
|
||||
"mode": "cron",
|
||||
"specific_times": ["02:00", "14:00"],
|
||||
}
|
||||
response = client.put("/api/scheduler/config", json=payload)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["config"]["specific_times"] == ["02:00", "14:00"]
|
||||
|
||||
|
||||
class TestSchedulerHistory:
|
||||
def test_empty_history(self, client, isolated_paths):
|
||||
response = client.get("/api/scheduler/history")
|
||||
assert response.status_code == 200
|
||||
assert response.json() == []
|
||||
|
||||
def test_history_returns_runs(self, client, isolated_paths):
|
||||
db = Database(db_path=str(isolated_paths.db_path))
|
||||
run_id = db.add_scheduler_run(datetime.now())
|
||||
db.finish_scheduler_run(run_id, "success", posts_processed=10, posts_downloaded=3)
|
||||
|
||||
response = client.get("/api/scheduler/history")
|
||||
assert response.status_code == 200
|
||||
runs = response.json()
|
||||
assert len(runs) == 1
|
||||
assert runs[0]["status"] == "success"
|
||||
assert runs[0]["posts_processed"] == 10
|
||||
|
||||
def test_history_limit_clamped(self, client, isolated_paths):
|
||||
response = client.get("/api/scheduler/history?limit=500")
|
||||
# Query validates limit <= 100
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
class TestRunNow:
|
||||
def test_returns_409_when_already_running(self, client):
|
||||
deps.collector_status["running"] = True
|
||||
response = client.post("/api/scheduler/run-now")
|
||||
assert response.status_code == 409
|
||||
|
||||
def test_triggers_background_task(self, client, isolated_paths, mock_subprocess):
|
||||
response = client.post("/api/scheduler/run-now")
|
||||
assert response.status_code == 200
|
||||
# BackgroundTasks runs synchronously after the response in TestClient,
|
||||
# so the fake subprocess should have been called once.
|
||||
assert len(mock_subprocess) == 1
|
||||
|
||||
|
||||
class TestCollectorRun:
|
||||
def test_returns_409_when_already_running(self, client):
|
||||
deps.collector_status["running"] = True
|
||||
response = client.post("/api/collector/run")
|
||||
assert response.status_code == 409
|
||||
|
||||
def test_triggers_background_task(self, client, mock_subprocess):
|
||||
response = client.post("/api/collector/run")
|
||||
assert response.status_code == 200
|
||||
assert len(mock_subprocess) == 1
|
||||
|
||||
|
||||
class TestIndividualCollect:
|
||||
def test_invalid_target_type(self, client):
|
||||
response = client.post(
|
||||
"/api/collect/individual",
|
||||
json={"target_type": "channel", "target_name": "test"},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_empty_target_name(self, client):
|
||||
response = client.post(
|
||||
"/api/collect/individual",
|
||||
json={"target_type": "user", "target_name": ""},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_invalid_media_type(self, client):
|
||||
response = client.post(
|
||||
"/api/collect/individual",
|
||||
json={"target_type": "user", "target_name": "alice", "media_types": ["spam"]},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_returns_409_when_already_running(self, client):
|
||||
deps.collector_status["running"] = True
|
||||
response = client.post(
|
||||
"/api/collect/individual",
|
||||
json={"target_type": "user", "target_name": "alice"},
|
||||
)
|
||||
assert response.status_code == 409
|
||||
|
||||
|
||||
class TestCollectTargets:
|
||||
def test_targets_structure(self, client, isolated_paths):
|
||||
response = client.get("/api/collect/targets")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert {"favorite_authors", "configured_subreddits", "configured_users"} <= set(data)
|
||||
Loading…
Add table
Add a link
Reference in a new issue