chore: checkpoint baseline (routers, tests, pyproject)

This commit is contained in:
authentik Default Admin 2026-05-17 11:40:35 +01:00
parent a58498a315
commit 986f1dfef4
37 changed files with 2527 additions and 1664 deletions

15
.dockerignore Normal file
View file

@ -0,0 +1,15 @@
# Files excluded from the Docker build context.
#
# Patterns are matched relative to the context root, so a bare "__pycache__"
# or "*.pyc" only matches at the top level. The "**/" prefix is needed to also
# exclude nested copies such as src/__pycache__/, which `COPY src/ ./src/`
# would otherwise ship into the image.

# Python bytecode and tool caches
**/__pycache__
**/*.pyc
.pytest_cache
.ruff_cache
.mypy_cache
.coverage
.coverage.*
htmlcov/

# Version control
.git
.gitignore

# Local runtime state (databases, downloaded media, thumbnails, logs)
*.db
*.sqlite
downloads/
data/
.thumbs/
*.log

# Secrets and local virtual environments
.env
venv/
.venv/

# The test suite is not needed inside the image
tests/

36
.github/workflows/ci.yml vendored Normal file
View file

@ -0,0 +1,36 @@
# Continuous integration: lint, run the test suite, then check that the
# Docker image still builds. Runs on pushes and pull requests targeting main.
name: CI

on:
  push:
    branches: [main]
  pull_request:
    branches: [main]

jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: "3.12"
      # Pin ruff to the same release as the pre-commit hook (rev v0.4.8) so
      # `ruff format --check` agrees with what contributors format locally
      # and a new upstream release cannot break CI unannounced.
      - run: pip install ruff==0.4.8
      - run: ruff check src/
      - run: ruff format --check src/

  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: "3.12"
      # Editable install including the "dev" extra (pytest, pytest-cov, ...).
      - run: pip install -e ".[dev]"
      - run: pytest tests/ --cov=src --cov-report=term-missing -v

  docker:
    runs-on: ubuntu-latest
    # Only build the image once lint and tests have passed.
    needs: [lint, test]
    steps:
      - uses: actions/checkout@v4
      - run: docker build -t reddit-media-collector .

6
.gitignore vendored
View file

@ -18,3 +18,9 @@ scheduler_config.yaml
.thumbs/
*.sqlite
.DS_Store
.coverage
.coverage.*
htmlcov/
.pytest_cache/
.ruff_cache/
.mypy_cache/

7
.pre-commit-config.yaml Normal file
View file

@ -0,0 +1,7 @@
# pre-commit hooks: lint and format the code base with ruff before each commit.
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.4.8
    hooks:
      # Linter; --fix lets ruff apply automatic fixes in place.
      - id: ruff
        args:
          - --fix
      # Formatter.
      - id: ruff-format

35
Dockerfile Normal file
View file

@ -0,0 +1,35 @@
# Multi-stage build for the Reddit media collector web app.

# --- base: Python runtime plus system packages shared by all stages ---------
FROM python:3.12-slim AS base
# NOTE(review): ffmpeg is presumably needed by yt-dlp for video handling - confirm.
RUN apt-get update && apt-get install -y --no-install-recommends \
        ffmpeg \
    && rm -rf /var/lib/apt/lists/*
WORKDIR /app

# --- builder: install the project and its dependencies ----------------------
FROM base AS builder
# README.md must be present: pyproject.toml declares `readme = "README.md"`
# and the build backend aborts if the referenced file is missing.
COPY pyproject.toml README.md ./
COPY src/ ./src/
RUN pip install --no-cache-dir .

# --- runtime: installed packages + application source ------------------------
FROM base AS runtime
COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
COPY --from=builder /usr/local/bin /usr/local/bin
WORKDIR /app
COPY src/ ./src/
RUN mkdir -p /app/downloads /app/data

# Default locations; each can be overridden at `docker run` / compose level.
ENV RMC_DOWNLOAD_DIR=/app/downloads
ENV RMC_DB_PATH=/app/data/media.db
ENV RMC_CONFIG_PATH=/app/config.yaml
ENV RMC_TIMEZONE=UTC

EXPOSE 8000

# Only directories are declared as volumes. /app/config.yaml is a single file:
# listing it here would make Docker create a *directory* at that path whenever
# no file is bind-mounted, so the config file must be bind-mounted explicitly
# (as docker-compose.yml does).
VOLUME ["/app/downloads", "/app/data"]

CMD ["uvicorn", "src.web.app:app", "--host", "0.0.0.0", "--port", "8000"]

15
docker-compose.yml Normal file
View file

@ -0,0 +1,15 @@
# Single-service stack: builds the image from the local Dockerfile and serves
# the web app on port 8000.
services:
  collector:
    container_name: reddit-media-collector
    build: .
    restart: unless-stopped
    ports:
      - "8000:8000"
    environment:
      RMC_TIMEZONE: UTC
      RMC_DOWNLOAD_DIR: /app/downloads
      RMC_DB_PATH: /app/data/media.db
    volumes:
      # Persist downloaded media and the SQLite database on the host.
      - ./downloads:/app/downloads
      - ./data:/app/data
      # Configuration is mounted read-only.
      - ./config.yaml:/app/config.yaml:ro

86
pyproject.toml Normal file
View file

@ -0,0 +1,86 @@
# Packaging metadata and developer-tool configuration.

# --- Build backend -----------------------------------------------------------
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

# --- Project metadata ---------------------------------------------------------
[project]
name = "reddit-media-collector"
version = "1.0.0"
description = "Self-hosted media collector for Reddit with Immich integration"
readme = "README.md"
requires-python = ">=3.11"
license = "MIT"
dependencies = [
    "requests>=2.31.0",
    "pyyaml>=6.0",
    "yt-dlp>=2024.0.0",
    "fastapi>=0.109.0",
    "uvicorn>=0.27.0",
    "jinja2>=3.1.0",
    "apscheduler>=3.10.0",
    "sqlalchemy>=2.0.0",
]

# Installed with `pip install -e ".[dev]"`.
[project.optional-dependencies]
dev = [
    "pytest>=8.0.0",
    "pytest-cov>=5.0.0",
    "pytest-asyncio>=0.23.0",
    "httpx>=0.27.0",
    "ruff>=0.4.0",
    "mypy>=1.10.0",
    "types-requests>=2.31.0",
    "types-PyYAML>=6.0",
    "pre-commit>=3.7.0",
]

# Console entry point.
[project.scripts]
reddit-collector = "src.main:main"

# The importable package is the top-level `src` directory.
[tool.hatch.build.targets.wheel]
packages = ["src"]

# --- Ruff (lint + format) -----------------------------------------------------
[tool.ruff]
target-version = "py311"
line-length = 120

[tool.ruff.lint]
select = [
    "E",   # pycodestyle errors
    "W",   # pycodestyle warnings
    "F",   # pyflakes
    "I",   # isort
    "N",   # pep8-naming
    "UP",  # pyupgrade
    "B",   # flake8-bugbear
    "SIM", # flake8-simplify
    "TCH", # flake8-type-checking
    "RUF", # ruff-specific rules
]
ignore = [
    "E501", # line too long (handled by formatter)
    "B008", # do not perform function calls in argument defaults (FastAPI uses this)
]

[tool.ruff.lint.isort]
known-first-party = ["src"]

# --- mypy -----------------------------------------------------------------------
[tool.mypy]
python_version = "3.11"
warn_return_any = true
warn_unused_configs = true
disallow_untyped_defs = false
check_untyped_defs = true
ignore_missing_imports = true

# --- pytest ---------------------------------------------------------------------
[tool.pytest.ini_options]
testpaths = ["tests"]
pythonpath = ["."]
addopts = "-v --tb=short"

# --- coverage -------------------------------------------------------------------
[tool.coverage.run]
source = ["src"]
omit = ["src/web/templates/*"]

[tool.coverage.report]
show_missing = true
skip_empty = true

View file

@ -1,7 +0,0 @@
requests>=2.31.0
pyyaml>=6.0
Pillow>=10.0.0
yt-dlp>=2024.0.0
fastapi>=0.109.0
uvicorn>=0.27.0
jinja2>=3.1.0

View file

@ -1,9 +1,9 @@
"""Configuration loader and validator."""
import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
import yaml
@ -49,7 +49,7 @@ class RateLimitConfig:
@dataclass
class LoggingConfig:
level: str = "INFO"
file: Optional[str] = "collector.log"
file: str | None = "collector.log"
@dataclass
@ -79,7 +79,7 @@ def load_config(config_path: str = "config.yaml") -> Config:
"Please copy config.yaml.example to config.yaml and customize."
)
with open(path, "r", encoding="utf-8") as f:
with open(path, encoding="utf-8") as f:
data = yaml.safe_load(f)
targets_data = data.get("targets", {})
@ -92,10 +92,7 @@ def load_config(config_path: str = "config.yaml") -> Config:
)
for s in targets_data.get("subreddits", [])
]
users = [
UserTarget(name=u["name"], limit=u.get("limit", 30))
for u in targets_data.get("users", [])
]
users = [UserTarget(name=u["name"], limit=u.get("limit", 30)) for u in targets_data.get("users", [])]
targets = TargetsConfig(subreddits=subreddits, users=users)
if not targets.subreddits and not targets.users:
@ -103,7 +100,7 @@ def load_config(config_path: str = "config.yaml") -> Config:
download_data = data.get("download", {})
download = DownloadConfig(
output_dir=download_data.get("output_dir", "./downloads"),
output_dir=os.environ.get("RMC_DOWNLOAD_DIR", download_data.get("output_dir", "./downloads")),
media_types=download_data.get("media_types", ["image", "video", "gif"]),
min_score=download_data.get("min_score", 10),
skip_nsfw=download_data.get("skip_nsfw", True),
@ -147,9 +144,7 @@ def setup_logging(config: LoggingConfig) -> logging.Logger:
logger = logging.getLogger("reddit_collector")
logger.setLevel(getattr(logging, config.level.upper()))
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
console_handler = logging.StreamHandler()
console_handler.setFormatter(formatter)

View file

@ -1,11 +1,13 @@
"""SQLite database for storing post metadata and tracking downloads."""
import os
import sqlite3
from contextlib import contextmanager
from contextlib import contextmanager, suppress
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Optional
DEFAULT_DB_PATH = os.environ.get("RMC_DB_PATH", "media.db")
@dataclass
@ -15,22 +17,22 @@ class PostRecord:
author: str
title: str
url: str
media_url: Optional[str]
media_type: Optional[str]
media_url: str | None
media_type: str | None
score: int
created_utc: float
downloaded_at: Optional[datetime]
local_path: Optional[str]
file_hash: Optional[str]
permalink: Optional[str] = None # Reddit permalink for Immich
source_type: Optional[str] = None # 'subreddit' or 'user'
flair: Optional[str] = None # Post flair for tagging
downloaded_at: datetime | None
local_path: str | None
file_hash: str | None
permalink: str | None = None # Reddit permalink for Immich
source_type: str | None = None # 'subreddit' or 'user'
flair: str | None = None # Post flair for tagging
class Database:
"""SQLite database wrapper for tracking downloaded posts."""
def __init__(self, db_path: str = "media.db"):
def __init__(self, db_path: str = DEFAULT_DB_PATH):
self.db_path = Path(db_path)
self._init_db()
@ -65,27 +67,15 @@ class Database:
)
""")
# Add new columns if they don't exist (migration)
try:
with suppress(sqlite3.OperationalError):
conn.execute("ALTER TABLE posts ADD COLUMN permalink TEXT")
except sqlite3.OperationalError:
pass
try:
with suppress(sqlite3.OperationalError):
conn.execute("ALTER TABLE posts ADD COLUMN source_type TEXT")
except sqlite3.OperationalError:
pass
try:
with suppress(sqlite3.OperationalError):
conn.execute("ALTER TABLE posts ADD COLUMN flair TEXT")
except sqlite3.OperationalError:
pass
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_subreddit ON posts(subreddit)"
)
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_file_hash ON posts(file_hash)"
)
conn.execute(
"CREATE INDEX IF NOT EXISTS idx_downloaded ON posts(downloaded_at)"
)
conn.execute("CREATE INDEX IF NOT EXISTS idx_subreddit ON posts(subreddit)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_file_hash ON posts(file_hash)")
conn.execute("CREATE INDEX IF NOT EXISTS idx_downloaded ON posts(downloaded_at)")
# Scheduler history table
conn.execute("""
CREATE TABLE IF NOT EXISTS scheduler_history (
@ -113,17 +103,13 @@ class Database:
def post_exists(self, post_id: str) -> bool:
"""Check if a post has already been processed."""
with self._get_connection() as conn:
cursor = conn.execute(
"SELECT 1 FROM posts WHERE id = ?", (post_id,)
)
cursor = conn.execute("SELECT 1 FROM posts WHERE id = ?", (post_id,))
return cursor.fetchone() is not None
def hash_exists(self, file_hash: str) -> Optional[str]:
def hash_exists(self, file_hash: str) -> str | None:
"""Check if a file with this hash already exists. Returns local_path if found."""
with self._get_connection() as conn:
cursor = conn.execute(
"SELECT local_path FROM posts WHERE file_hash = ?", (file_hash,)
)
cursor = conn.execute("SELECT local_path FROM posts WHERE file_hash = ?", (file_hash,))
row = cursor.fetchone()
return row["local_path"] if row else None
@ -158,9 +144,7 @@ class Database:
)
conn.commit()
def mark_downloaded(
self, post_id: str, local_path: str, file_hash: str
) -> None:
def mark_downloaded(self, post_id: str, local_path: str, file_hash: str) -> None:
"""Mark a post as downloaded with file info."""
with self._get_connection() as conn:
conn.execute(
@ -182,12 +166,10 @@ class Database:
)
conn.commit()
def get_post(self, post_id: str) -> Optional[PostRecord]:
def get_post(self, post_id: str) -> PostRecord | None:
"""Get a post record by ID."""
with self._get_connection() as conn:
cursor = conn.execute(
"SELECT * FROM posts WHERE id = ?", (post_id,)
)
cursor = conn.execute("SELECT * FROM posts WHERE id = ?", (post_id,))
row = cursor.fetchone()
if not row:
return None
@ -204,54 +186,49 @@ class Database:
downloaded_at=row["downloaded_at"],
local_path=row["local_path"],
file_hash=row["file_hash"],
permalink=row["permalink"] if "permalink" in row.keys() else None,
source_type=row["source_type"] if "source_type" in row.keys() else None,
flair=row["flair"] if "flair" in row.keys() else None,
permalink=row["permalink"],
source_type=row["source_type"],
flair=row["flair"],
)
def get_all_downloaded(self) -> list[PostRecord]:
"""Get all downloaded posts for migration."""
with self._get_connection() as conn:
cursor = conn.execute(
"SELECT * FROM posts WHERE downloaded_at IS NOT NULL"
)
cursor = conn.execute("SELECT * FROM posts WHERE downloaded_at IS NOT NULL")
posts = []
for row in cursor.fetchall():
posts.append(PostRecord(
id=row["id"],
subreddit=row["subreddit"],
author=row["author"],
title=row["title"],
url=row["url"],
media_url=row["media_url"],
media_type=row["media_type"],
score=row["score"],
created_utc=row["created_utc"],
downloaded_at=row["downloaded_at"],
local_path=row["local_path"],
file_hash=row["file_hash"],
permalink=row["permalink"] if "permalink" in row.keys() else None,
source_type=row["source_type"] if "source_type" in row.keys() else None,
flair=row["flair"] if "flair" in row.keys() else None,
))
posts.append(
PostRecord(
id=row["id"],
subreddit=row["subreddit"],
author=row["author"],
title=row["title"],
url=row["url"],
media_url=row["media_url"],
media_type=row["media_type"],
score=row["score"],
created_utc=row["created_utc"],
downloaded_at=row["downloaded_at"],
local_path=row["local_path"],
file_hash=row["file_hash"],
permalink=row["permalink"],
source_type=row["source_type"],
flair=row["flair"],
)
)
return posts
def update_local_path(self, post_id: str, new_path: str) -> None:
"""Update local_path for a post (used in migration)."""
with self._get_connection() as conn:
conn.execute(
"UPDATE posts SET local_path = ? WHERE id = ?",
(new_path, post_id)
)
conn.execute("UPDATE posts SET local_path = ? WHERE id = ?", (new_path, post_id))
conn.commit()
def get_stats(self) -> dict:
"""Get collection statistics."""
with self._get_connection() as conn:
total = conn.execute("SELECT COUNT(*) FROM posts").fetchone()[0]
downloaded = conn.execute(
"SELECT COUNT(*) FROM posts WHERE downloaded_at IS NOT NULL"
).fetchone()[0]
downloaded = conn.execute("SELECT COUNT(*) FROM posts WHERE downloaded_at IS NOT NULL").fetchone()[0]
# Group by source: subreddits as "r/name", users as "u/name"
by_source = {}
@ -312,30 +289,30 @@ class Database:
month_start = today_start - timedelta(days=30)
# Format dates for SQLite comparison (space separator, not 'T')
today_str = today_start.strftime('%Y-%m-%d %H:%M:%S')
week_str = week_start.strftime('%Y-%m-%d %H:%M:%S')
month_str = month_start.strftime('%Y-%m-%d %H:%M:%S')
today_str = today_start.strftime("%Y-%m-%d %H:%M:%S")
week_str = week_start.strftime("%Y-%m-%d %H:%M:%S")
month_str = month_start.strftime("%Y-%m-%d %H:%M:%S")
# Downloads by period
downloads_today = conn.execute(
"SELECT COUNT(*) FROM posts WHERE downloaded_at >= ?",
(today_str,)
"SELECT COUNT(*) FROM posts WHERE downloaded_at >= ?", (today_str,)
).fetchone()[0]
downloads_week = conn.execute(
"SELECT COUNT(*) FROM posts WHERE downloaded_at >= ?",
(week_str,)
"SELECT COUNT(*) FROM posts WHERE downloaded_at >= ?", (week_str,)
).fetchone()[0]
downloads_month = conn.execute(
"SELECT COUNT(*) FROM posts WHERE downloaded_at >= ?",
(month_str,)
"SELECT COUNT(*) FROM posts WHERE downloaded_at >= ?", (month_str,)
).fetchone()[0]
# Average score
avg_score = conn.execute(
"SELECT AVG(score) FROM posts WHERE downloaded_at IS NOT NULL AND score IS NOT NULL"
).fetchone()[0] or 0
avg_score = (
conn.execute(
"SELECT AVG(score) FROM posts WHERE downloaded_at IS NOT NULL AND score IS NOT NULL"
).fetchone()[0]
or 0
)
# Unique authors
unique_authors = conn.execute(
@ -343,9 +320,7 @@ class Database:
).fetchone()[0]
# Favorites count
favorites_count = conn.execute(
"SELECT COUNT(*) FROM favorites"
).fetchone()[0]
favorites_count = conn.execute("SELECT COUNT(*) FROM favorites").fetchone()[0]
# Last download
last_download = conn.execute(
@ -358,18 +333,16 @@ class Database:
"SELECT downloaded_at FROM posts WHERE downloaded_at IS NOT NULL ORDER BY downloaded_at ASC LIMIT 1"
).fetchone()
total_downloaded = conn.execute(
"SELECT COUNT(*) FROM posts WHERE downloaded_at IS NOT NULL"
).fetchone()[0]
total_downloaded = conn.execute("SELECT COUNT(*) FROM posts WHERE downloaded_at IS NOT NULL").fetchone()[0]
# Calculate average per day
avg_per_day = 0
if first_download and first_download[0]:
try:
first_date = datetime.fromisoformat(first_download[0].replace('Z', '+00:00'))
first_date = datetime.fromisoformat(first_download[0].replace("Z", "+00:00"))
days_active = max((now - first_date.replace(tzinfo=None)).days, 1)
avg_per_day = round(total_downloaded / days_active, 1)
except:
except (ValueError, TypeError):
avg_per_day = 0
# Top 10 authors
@ -389,16 +362,12 @@ class Database:
for i in range(13, -1, -1):
day = today_start - timedelta(days=i)
day_end = day + timedelta(days=1)
day_str = day.strftime('%Y-%m-%d %H:%M:%S')
day_end_str = day_end.strftime('%Y-%m-%d %H:%M:%S')
day_str = day.strftime("%Y-%m-%d %H:%M:%S")
day_end_str = day_end.strftime("%Y-%m-%d %H:%M:%S")
count = conn.execute(
"SELECT COUNT(*) FROM posts WHERE downloaded_at >= ? AND downloaded_at < ?",
(day_str, day_end_str)
"SELECT COUNT(*) FROM posts WHERE downloaded_at >= ? AND downloaded_at < ?", (day_str, day_end_str)
).fetchone()[0]
trend_data.append({
"date": day.strftime("%m/%d"),
"count": count
})
trend_data.append({"date": day.strftime("%m/%d"), "count": count})
return {
"downloads_today": downloads_today,
@ -410,7 +379,7 @@ class Database:
"last_download": last_download,
"avg_per_day": avg_per_day,
"top_authors": [{"author": a, "count": c} for a, c in top_authors],
"trend": trend_data
"trend": trend_data,
}
def get_recent_downloads(self, limit: int = 10) -> list[dict]:
@ -425,13 +394,19 @@ class Database:
ORDER BY downloaded_at DESC
LIMIT ?
""",
(limit,)
(limit,),
)
return [dict(row) for row in cursor.fetchall()]
def get_media_files(self, limit: int = 50, offset: int = 0,
subreddit: str = None, media_type: str = None,
sort: str = "newest", author: str = None) -> list[dict]:
def get_media_files(
self,
limit: int = 50,
offset: int = 0,
subreddit: str | None = None,
media_type: str | None = None,
sort: str = "newest",
author: str | None = None,
) -> list[dict]:
"""Get media files with optional filtering and sorting.
Args:
@ -474,7 +449,9 @@ class Database:
cursor = conn.execute(query, params)
return [dict(row) for row in cursor.fetchall()]
def get_total_media_count(self, subreddit: str = None, media_type: str = None, author: str = None) -> int:
def get_total_media_count(
self, subreddit: str | None = None, media_type: str | None = None, author: str | None = None
) -> int:
"""Get total count of media files with optional filtering."""
with self._get_connection() as conn:
query = """
@ -534,7 +511,7 @@ class Database:
with self._get_connection() as conn:
# Create placeholders for IN clause
placeholders = ','.join('?' * len(authors))
placeholders = ",".join("?" * len(authors))
# Convert authors to lowercase for case-insensitive matching
authors_lower = [a.lower() for a in authors]
@ -545,28 +522,30 @@ class Database:
AND local_path IS NOT NULL
AND downloaded_at IS NOT NULL
""",
authors_lower
authors_lower,
)
posts = []
for row in cursor.fetchall():
posts.append(PostRecord(
id=row["id"],
subreddit=row["subreddit"],
author=row["author"],
title=row["title"],
url=row["url"],
media_url=row["media_url"],
media_type=row["media_type"],
score=row["score"],
created_utc=row["created_utc"],
downloaded_at=row["downloaded_at"],
local_path=row["local_path"],
file_hash=row["file_hash"],
permalink=row["permalink"] if "permalink" in row.keys() else None,
source_type=row["source_type"] if "source_type" in row.keys() else None,
flair=row["flair"] if "flair" in row.keys() else None,
))
posts.append(
PostRecord(
id=row["id"],
subreddit=row["subreddit"],
author=row["author"],
title=row["title"],
url=row["url"],
media_url=row["media_url"],
media_type=row["media_type"],
score=row["score"],
created_utc=row["created_utc"],
downloaded_at=row["downloaded_at"],
local_path=row["local_path"],
file_hash=row["file_hash"],
permalink=row["permalink"],
source_type=row["source_type"],
flair=row["flair"],
)
)
return posts
def count_posts_by_authors(self, authors: list[str]) -> int:
@ -575,7 +554,7 @@ class Database:
return 0
with self._get_connection() as conn:
placeholders = ','.join('?' * len(authors))
placeholders = ",".join("?" * len(authors))
authors_lower = [a.lower() for a in authors]
cursor = conn.execute(
@ -585,7 +564,7 @@ class Database:
AND local_path IS NOT NULL
AND downloaded_at IS NOT NULL
""",
authors_lower
authors_lower,
)
return cursor.fetchone()[0]
@ -595,7 +574,7 @@ class Database:
return []
with self._get_connection() as conn:
placeholders = ','.join('?' * len(subreddits))
placeholders = ",".join("?" * len(subreddits))
subreddits_lower = [s.lower() for s in subreddits]
cursor = conn.execute(
@ -605,28 +584,30 @@ class Database:
AND local_path IS NOT NULL
AND downloaded_at IS NOT NULL
""",
subreddits_lower
subreddits_lower,
)
posts = []
for row in cursor.fetchall():
posts.append(PostRecord(
id=row["id"],
subreddit=row["subreddit"],
author=row["author"],
title=row["title"],
url=row["url"],
media_url=row["media_url"],
media_type=row["media_type"],
score=row["score"],
created_utc=row["created_utc"],
downloaded_at=row["downloaded_at"],
local_path=row["local_path"],
file_hash=row["file_hash"],
permalink=row["permalink"] if "permalink" in row.keys() else None,
source_type=row["source_type"] if "source_type" in row.keys() else None,
flair=row["flair"] if "flair" in row.keys() else None,
))
posts.append(
PostRecord(
id=row["id"],
subreddit=row["subreddit"],
author=row["author"],
title=row["title"],
url=row["url"],
media_url=row["media_url"],
media_type=row["media_type"],
score=row["score"],
created_utc=row["created_utc"],
downloaded_at=row["downloaded_at"],
local_path=row["local_path"],
file_hash=row["file_hash"],
permalink=row["permalink"],
source_type=row["source_type"],
flair=row["flair"],
)
)
return posts
def count_posts_by_subreddits(self, subreddits: list[str]) -> int:
@ -635,7 +616,7 @@ class Database:
return 0
with self._get_connection() as conn:
placeholders = ','.join('?' * len(subreddits))
placeholders = ",".join("?" * len(subreddits))
subreddits_lower = [s.lower() for s in subreddits]
cursor = conn.execute(
@ -645,7 +626,7 @@ class Database:
AND local_path IS NOT NULL
AND downloaded_at IS NOT NULL
""",
subreddits_lower
subreddits_lower,
)
return cursor.fetchone()[0]
@ -655,10 +636,7 @@ class Database:
"""Add a post to favorites. Returns True if added, False if already exists."""
with self._get_connection() as conn:
try:
conn.execute(
"INSERT INTO favorites (post_id) VALUES (?)",
(post_id,)
)
conn.execute("INSERT INTO favorites (post_id) VALUES (?)", (post_id,))
conn.commit()
return True
except sqlite3.IntegrityError:
@ -667,20 +645,14 @@ class Database:
def remove_favorite(self, post_id: str) -> bool:
"""Remove a post from favorites. Returns True if removed."""
with self._get_connection() as conn:
cursor = conn.execute(
"DELETE FROM favorites WHERE post_id = ?",
(post_id,)
)
cursor = conn.execute("DELETE FROM favorites WHERE post_id = ?", (post_id,))
conn.commit()
return cursor.rowcount > 0
def is_favorite(self, post_id: str) -> bool:
"""Check if a post is favorited."""
with self._get_connection() as conn:
cursor = conn.execute(
"SELECT 1 FROM favorites WHERE post_id = ?",
(post_id,)
)
cursor = conn.execute("SELECT 1 FROM favorites WHERE post_id = ?", (post_id,))
return cursor.fetchone() is not None
def get_favorites(self, limit: int = 50, offset: int = 0) -> list[dict]:
@ -696,7 +668,7 @@ class Database:
ORDER BY f.favorited_at DESC
LIMIT ? OFFSET ?
""",
(limit, offset)
(limit, offset),
)
return [dict(row) for row in cursor.fetchall()]
@ -727,16 +699,16 @@ class Database:
authors: list[str],
limit: int = 50,
offset: int = 0,
subreddit: str = None,
media_type: str = None,
sort: str = "newest"
subreddit: str | None = None,
media_type: str | None = None,
sort: str = "newest",
) -> list[dict]:
"""Get media files from specific authors."""
if not authors:
return []
with self._get_connection() as conn:
placeholders = ','.join('?' * len(authors))
placeholders = ",".join("?" * len(authors))
authors_lower = [a.lower() for a in authors]
query = f"""
@ -774,17 +746,14 @@ class Database:
return [dict(row) for row in cursor.fetchall()]
def count_media_by_authors(
self,
authors: list[str],
subreddit: str = None,
media_type: str = None
self, authors: list[str], subreddit: str | None = None, media_type: str | None = None
) -> int:
"""Count media files from specific authors."""
if not authors:
return 0
with self._get_connection() as conn:
placeholders = ','.join('?' * len(authors))
placeholders = ",".join("?" * len(authors))
authors_lower = [a.lower() for a in authors]
query = f"""
@ -807,11 +776,7 @@ class Database:
return cursor.fetchone()[0]
def get_authors_with_stats(
self,
limit: int = 50,
offset: int = 0,
favorites_only: bool = False,
sort: str = "count"
self, limit: int = 50, offset: int = 0, favorites_only: bool = False, sort: str = "count"
) -> list[dict]:
"""Get list of authors with their media counts and a sample thumbnail.
@ -885,22 +850,27 @@ class Database:
for row in rows:
author_name = row[0]
# Check if author has any favorited posts
fav_cursor = conn.execute("""
fav_cursor = conn.execute(
"""
SELECT COUNT(*) FROM favorites f
JOIN posts p ON f.post_id = p.id
WHERE p.author = ?
""", (author_name,))
""",
(author_name,),
)
fav_count = fav_cursor.fetchone()[0]
authors.append({
"author": author_name,
"media_count": row[1],
"max_score": row[2],
"total_score": row[3],
"latest_post": row[4],
"thumb_path": row[5],
"is_favorite": fav_count > 0
})
authors.append(
{
"author": author_name,
"media_count": row[1],
"max_score": row[2],
"total_score": row[3],
"latest_post": row[4],
"thumb_path": row[5],
"is_favorite": fav_count > 0,
}
)
return authors
@ -939,8 +909,7 @@ class Database:
"""Start a new scheduler run. Returns the run ID."""
with self._get_connection() as conn:
cursor = conn.execute(
"INSERT INTO scheduler_history (started_at, status) VALUES (?, 'running')",
(started_at,)
"INSERT INTO scheduler_history (started_at, status) VALUES (?, 'running')", (started_at,)
)
conn.commit()
return cursor.lastrowid
@ -951,7 +920,7 @@ class Database:
status: str,
posts_processed: int = 0,
posts_downloaded: int = 0,
error_message: str = None
error_message: str | None = None,
) -> None:
"""Finish a scheduler run with results."""
with self._get_connection() as conn:
@ -962,7 +931,7 @@ class Database:
posts_downloaded = ?, error_message = ?
WHERE id = ?
""",
(datetime.now(), status, posts_processed, posts_downloaded, error_message, run_id)
(datetime.now(), status, posts_processed, posts_downloaded, error_message, run_id),
)
conn.commit()
@ -977,11 +946,11 @@ class Database:
ORDER BY started_at DESC
LIMIT ?
""",
(limit,)
(limit,),
)
return [dict(row) for row in cursor.fetchall()]
def get_last_scheduler_run(self) -> Optional[dict]:
def get_last_scheduler_run(self) -> dict | None:
"""Get the most recent scheduler run."""
with self._get_connection() as conn:
cursor = conn.execute(

View file

@ -6,7 +6,6 @@ import mimetypes
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse
import requests
@ -20,6 +19,7 @@ logger = logging.getLogger("reddit_collector")
@dataclass
class DownloadMetadata:
"""Metadata needed for downloading and sidecar generation."""
subreddit: str
author: str
title: str
@ -27,10 +27,11 @@ class DownloadMetadata:
created_utc: float
post_id: str
media_type: str
gallery_index: Optional[int] = None
permalink: Optional[str] = None
flair: Optional[str] = None
source_type: Optional[str] = None
gallery_index: int | None = None
permalink: str | None = None
flair: str | None = None
source_type: str | None = None
MIME_TO_EXT = {
"image/jpeg": ".jpg",
@ -60,17 +61,19 @@ class Downloader:
def _create_session(self) -> requests.Session:
"""Create a requests session with appropriate headers."""
session = requests.Session()
session.headers.update({
"User-Agent": "Mozilla/5.0 (compatible; RedditMediaCollector/1.0)",
"Accept": "image/*,video/*,*/*",
})
session.headers.update(
{
"User-Agent": "Mozilla/5.0 (compatible; RedditMediaCollector/1.0)",
"Accept": "image/*,video/*,*/*",
}
)
return session
def download(
self,
url: str,
metadata: DownloadMetadata,
) -> tuple[Optional[str], Optional[str]]:
) -> tuple[str | None, str | None]:
"""
Download media from URL.
Returns (local_path, file_hash) or (None, None) on failure.
@ -117,10 +120,9 @@ class Downloader:
except requests.exceptions.RequestException as e:
last_error = e
if attempt < max_retries - 1:
wait_time = 2 ** attempt
wait_time = 2**attempt
logger.warning(
f"Download failed (attempt {attempt + 1}/{max_retries}), "
f"retrying in {wait_time}s: {e}"
f"Download failed (attempt {attempt + 1}/{max_retries}), retrying in {wait_time}s: {e}"
)
time.sleep(wait_time)
@ -137,9 +139,7 @@ class Downloader:
content_length = response.headers.get("Content-Length")
if content_length and int(content_length) > self.max_size:
raise ValueError(
f"File too large: {int(content_length) / 1024 / 1024:.1f}MB"
)
raise ValueError(f"File too large: {int(content_length) / 1024 / 1024:.1f}MB")
ext = self._get_extension(url, response.headers.get("Content-Type"))
@ -184,9 +184,7 @@ class Downloader:
return str(local_path), file_hash
def _get_extension(
self, url: str, content_type: Optional[str]
) -> str:
def _get_extension(self, url: str, content_type: str | None) -> str:
"""Determine file extension from URL or Content-Type."""
if content_type:
content_type = content_type.split(";")[0].strip()

View file

@ -1,12 +1,11 @@
"""Media URL extractors for various hosts."""
import logging
from typing import Optional
from urllib.parse import urlparse
from .gfycat import extract_gfycat_url
from .imgur import extract_imgur_url
from .reddit import extract_reddit_video_url
from .gfycat import extract_gfycat_url
logger = logging.getLogger("reddit_collector")

View file

@ -1,12 +1,11 @@
"""Gfycat/Redgifs URL extractor using yt-dlp."""
import logging
from typing import Optional
logger = logging.getLogger("reddit_collector")
def extract_gfycat_url(url: str) -> Optional[str]:
def extract_gfycat_url(url: str) -> str | None:
"""
Extract video URL from Gfycat/Redgifs links.
Uses yt-dlp for extraction.
@ -26,10 +25,7 @@ def extract_gfycat_url(url: str) -> Optional[str]:
return info["url"]
if info and "formats" in info:
mp4_formats = [
f for f in info["formats"]
if f.get("ext") == "mp4" and f.get("url")
]
mp4_formats = [f for f in info["formats"] if f.get("ext") == "mp4" and f.get("url")]
if mp4_formats:
best = max(mp4_formats, key=lambda x: x.get("height", 0))
return best["url"]

View file

@ -2,13 +2,12 @@
import logging
import re
from typing import Optional
from urllib.parse import urlparse
logger = logging.getLogger("reddit_collector")
def extract_imgur_url(url: str) -> tuple[Optional[str], str]:
def extract_imgur_url(url: str) -> tuple[str | None, str]:
"""
Extract direct image/video URL from Imgur links.
Returns (url, media_type) or (None, "image") on failure.

View file

@ -1,12 +1,11 @@
"""Reddit video (v.redd.it) URL extractor using yt-dlp."""
import logging
from typing import Optional
logger = logging.getLogger("reddit_collector")
def extract_reddit_video_url(url: str) -> Optional[str]:
def extract_reddit_video_url(url: str) -> str | None:
"""
Extract video URL from v.redd.it links.
Uses yt-dlp to get the actual video URL.
@ -27,10 +26,7 @@ def extract_reddit_video_url(url: str) -> Optional[str]:
return info["url"]
if info and "formats" in info:
mp4_formats = [
f for f in info["formats"]
if f.get("ext") == "mp4" and f.get("url")
]
mp4_formats = [f for f in info["formats"] if f.get("ext") == "mp4" and f.get("url")]
if mp4_formats:
best = max(mp4_formats, key=lambda x: x.get("height", 0))
return best["url"]

View file

@ -57,13 +57,10 @@ def is_domain_blacklisted(url: str, blacklist_domains: list[str]) -> bool:
return False
url_lower = url.lower()
for domain in blacklist_domains:
if domain in url_lower:
return True
return False
return any(domain in url_lower for domain in blacklist_domains)
def should_download_media(media_type: str, config: Config, author: str = None, db: Database = None) -> bool:
def should_download_media(media_type: str, config: Config, author: str | None = None, db: Database = None) -> bool:
"""Check if media type is allowed. For videos, optionally check if author is favorited."""
if media_type not in config.download.media_types:
return False
@ -116,10 +113,7 @@ def process_post(
# Process each media item
for idx, (media_url, media_type) in enumerate(media_urls):
# Generate unique ID for gallery items
if len(media_urls) > 1:
item_id = f"{post.id}_{idx + 1}"
else:
item_id = post.id
item_id = f"{post.id}_{idx + 1}" if len(media_urls) > 1 else post.id
# Skip if already in database
if db.post_exists(item_id):
@ -178,12 +172,12 @@ def process_post(
# Correct media_type based on actual file extension
ext = os.path.splitext(local_path)[1].lower()
actual_type = final_type
if ext in ('.jpg', '.jpeg', '.png', '.webp'):
actual_type = 'image'
elif ext == '.gif':
actual_type = 'gif'
elif ext in ('.mp4', '.webm', '.mov'):
actual_type = 'video'
if ext in (".jpg", ".jpeg", ".png", ".webp"):
actual_type = "image"
elif ext == ".gif":
actual_type = "gif"
elif ext in (".mp4", ".webm", ".mov"):
actual_type = "video"
if actual_type != final_type:
logger.debug(f"Corrected media_type for {item_id}: {final_type} -> {actual_type}")
@ -211,14 +205,9 @@ def process_post(
stats.downloaded += 1
if len(media_urls) > 1:
logger.info(
f"Downloaded: {item_id} from r/{post.subreddit} ({actual_type}) "
f"[{idx + 1}/{len(media_urls)}]"
)
logger.info(f"Downloaded: {item_id} from r/{post.subreddit} ({actual_type}) [{idx + 1}/{len(media_urls)}]")
else:
logger.info(
f"Downloaded: {item_id} from r/{post.subreddit} ({actual_type})"
)
logger.info(f"Downloaded: {item_id} from r/{post.subreddit} ({actual_type})")
def collect(config: Config, logger) -> CollectionStats:
@ -233,10 +222,7 @@ def collect(config: Config, logger) -> CollectionStats:
logger.info(f"Processing subreddit: r/{target.name}")
try:
for post in client.get_subreddit_posts(target):
process_post(
post, client, db, downloader, config, stats, logger,
source_type="subreddit"
)
process_post(post, client, db, downloader, config, stats, logger, source_type="subreddit")
except Exception as e:
logger.error(f"Error processing r/{target.name}: {e}")
stats.errors += 1
@ -245,10 +231,7 @@ def collect(config: Config, logger) -> CollectionStats:
logger.info(f"Processing user: u/{target.name}")
try:
for post in client.get_user_posts(target):
process_post(
post, client, db, downloader, config, stats, logger,
source_type="user"
)
process_post(post, client, db, downloader, config, stats, logger, source_type="user")
except Exception as e:
logger.error(f"Error processing u/{target.name}: {e}")
stats.errors += 1
@ -289,11 +272,10 @@ def print_report(stats: CollectionStats, db: Database, logger) -> None:
def main():
"""Entry point."""
parser = argparse.ArgumentParser(
description="Collect images and videos from Reddit"
)
parser = argparse.ArgumentParser(description="Collect images and videos from Reddit")
parser.add_argument(
"-c", "--config",
"-c",
"--config",
default="config.yaml",
help="Path to configuration file (default: config.yaml)",
)

View file

@ -2,8 +2,8 @@
import logging
import time
from collections.abc import Iterator
from dataclasses import dataclass
from typing import Iterator, Optional
import requests
@ -17,6 +17,7 @@ BASE_URL = "https://www.reddit.com"
@dataclass
class Post:
"""Represents a Reddit post."""
id: str
subreddit: str
author: str
@ -26,10 +27,10 @@ class Post:
created_utc: float
over_18: bool
is_gallery: bool
preview: Optional[dict]
media_metadata: Optional[dict]
permalink: Optional[str] = None # Reddit permalink
flair: Optional[str] = None # Post flair text
preview: dict | None
media_metadata: dict | None
permalink: str | None = None # Reddit permalink
flair: str | None = None # Post flair text
class RateLimiter:
@ -55,14 +56,16 @@ class RedditClient:
def __init__(self, rate_config: RateLimitConfig):
self.rate_limiter = RateLimiter(rate_config.requests_per_minute)
self.session = requests.Session()
self.session.headers.update({
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
"AppleWebKit/537.36 (KHTML, like Gecko) "
"Chrome/120.0.0.0 Safari/537.36",
})
self.session.headers.update(
{
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
"AppleWebKit/537.36 (KHTML, like Gecko) "
"Chrome/120.0.0.0 Safari/537.36",
}
)
logger.info("Reddit client initialized (public JSON API)")
def _fetch_json(self, url: str, params: dict = None) -> dict:
def _fetch_json(self, url: str, params: dict | None = None) -> dict:
"""Fetch JSON from Reddit with rate limiting and error handling."""
self.rate_limiter.wait()
@ -105,14 +108,9 @@ class RedditClient:
flair=flair,
)
def get_subreddit_posts(
self, target: SubredditTarget
) -> Iterator[Post]:
def get_subreddit_posts(self, target: SubredditTarget) -> Iterator[Post]:
"""Fetch posts from a subreddit."""
logger.info(
f"Fetching {target.limit} posts from r/{target.name} "
f"(sort: {target.sort})"
)
logger.info(f"Fetching {target.limit} posts from r/{target.name} (sort: {target.sort})")
url = f"{BASE_URL}/r/{target.name}/{target.sort}.json"
params = {"limit": min(target.limit, 100)}
@ -224,7 +222,7 @@ class RedditClient:
if not post.media_metadata:
return urls
for item_id, item in post.media_metadata.items():
for _item_id, item in post.media_metadata.items():
if item.get("status") == "valid" and item.get("e") == "Image":
source = item.get("s", {})
if "u" in source:

View file

@ -2,9 +2,7 @@
import json
import logging
from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
from datetime import UTC, datetime
logger = logging.getLogger("reddit_collector")
@ -17,9 +15,9 @@ def write_immich_sidecar(
score: int,
created_utc: float,
media_type: str,
permalink: Optional[str] = None,
flair: Optional[str] = None,
source_type: Optional[str] = None,
permalink: str | None = None,
flair: str | None = None,
source_type: str | None = None,
) -> str:
"""
Write a JSON sidecar file for Immich.
@ -45,7 +43,7 @@ def write_immich_sidecar(
sidecar_path = f"{filepath}.json"
# Convert timestamp to ISO format
dt = datetime.fromtimestamp(created_utc, tz=timezone.utc)
dt = datetime.fromtimestamp(created_utc, tz=UTC)
date_iso = dt.isoformat()
# Build tags list
@ -99,7 +97,7 @@ def generate_filename(
created_utc: float,
post_id: str,
ext: str,
gallery_index: Optional[int] = None,
gallery_index: int | None = None,
) -> str:
"""
Generate a descriptive filename with metadata embedded.
@ -122,7 +120,7 @@ def generate_filename(
Generated filename
"""
# Convert timestamp to datetime
dt = datetime.fromtimestamp(created_utc, tz=timezone.utc)
dt = datetime.fromtimestamp(created_utc, tz=UTC)
date_str = dt.strftime("%Y%m%d_%H%M%S")
# Sanitize components

File diff suppressed because it is too large Load diff

View file

@ -1,12 +1,12 @@
"""Configuration file manager for CRUD operations."""
import os
from pathlib import Path
from typing import Any
import yaml
CONFIG_PATH = Path(__file__).parent.parent.parent / "config.yaml"
CONFIG_PATH = Path(os.environ.get("RMC_CONFIG_PATH", str(Path(__file__).parent.parent.parent / "config.yaml")))
def load_config() -> dict[str, Any]:
@ -14,7 +14,7 @@ def load_config() -> dict[str, Any]:
if not CONFIG_PATH.exists():
return {"targets": {"subreddits": [], "users": []}}
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
with open(CONFIG_PATH, encoding="utf-8") as f:
return yaml.safe_load(f) or {}
@ -44,11 +44,7 @@ def add_subreddit(name: str, limit: int = 100, sort: str = "new") -> bool:
if sub["name"].lower() == name.lower():
return False
config["targets"]["subreddits"].append({
"name": name,
"limit": limit,
"sort": sort
})
config["targets"]["subreddits"].append({"name": name, "limit": limit, "sort": sort})
save_config(config)
return True
@ -60,9 +56,7 @@ def remove_subreddit(name: str) -> bool:
subreddits = config.get("targets", {}).get("subreddits", [])
original_len = len(subreddits)
config["targets"]["subreddits"] = [
s for s in subreddits if s["name"].lower() != name.lower()
]
config["targets"]["subreddits"] = [s for s in subreddits if s["name"].lower() != name.lower()]
if len(config["targets"]["subreddits"]) < original_len:
save_config(config)
@ -90,10 +84,7 @@ def add_user(name: str, limit: int = 100) -> bool:
if user["name"].lower() == name.lower():
return False
config["targets"]["users"].append({
"name": name,
"limit": limit
})
config["targets"]["users"].append({"name": name, "limit": limit})
save_config(config)
return True
@ -105,9 +96,7 @@ def remove_user(name: str) -> bool:
users = config.get("targets", {}).get("users", [])
original_len = len(users)
config["targets"]["users"] = [
u for u in users if u["name"].lower() != name.lower()
]
config["targets"]["users"] = [u for u in users if u["name"].lower() != name.lower()]
if len(config["targets"]["users"]) < original_len:
save_config(config)
@ -117,6 +106,7 @@ def remove_user(name: str) -> bool:
# Blacklist functions
def _ensure_blacklist(config: dict) -> dict:
"""Ensure blacklist structure exists in config."""
if "blacklist" not in config:
@ -148,8 +138,7 @@ def add_blacklist_author(author: str) -> bool:
# Also remove from users collection if present
if "targets" in config and "users" in config["targets"]:
config["targets"]["users"] = [
u for u in config["targets"]["users"]
if u.get("name", "").lower() != author.lower()
u for u in config["targets"]["users"] if u.get("name", "").lower() != author.lower()
]
save_config(config)
@ -162,9 +151,7 @@ def remove_blacklist_author(author: str) -> bool:
config = _ensure_blacklist(config)
original_len = len(config["blacklist"]["authors"])
config["blacklist"]["authors"] = [
a for a in config["blacklist"]["authors"] if a.lower() != author.lower()
]
config["blacklist"]["authors"] = [a for a in config["blacklist"]["authors"] if a.lower() != author.lower()]
if len(config["blacklist"]["authors"]) < original_len:
save_config(config)
@ -186,8 +173,7 @@ def add_blacklist_subreddit(subreddit: str) -> bool:
# Also remove from subreddits collection if present
if "targets" in config and "subreddits" in config["targets"]:
config["targets"]["subreddits"] = [
s for s in config["targets"]["subreddits"]
if s.get("name", "").lower() != subreddit.lower()
s for s in config["targets"]["subreddits"] if s.get("name", "").lower() != subreddit.lower()
]
save_config(config)
@ -200,9 +186,7 @@ def remove_blacklist_subreddit(subreddit: str) -> bool:
config = _ensure_blacklist(config)
original_len = len(config["blacklist"]["subreddits"])
config["blacklist"]["subreddits"] = [
s for s in config["blacklist"]["subreddits"] if s.lower() != subreddit.lower()
]
config["blacklist"]["subreddits"] = [s for s in config["blacklist"]["subreddits"] if s.lower() != subreddit.lower()]
if len(config["blacklist"]["subreddits"]) < original_len:
save_config(config)
@ -266,9 +250,7 @@ def remove_blacklist_domain(domain: str) -> bool:
domain = domain.lower().replace("https://", "").replace("http://", "").strip("/")
original_len = len(config["blacklist"]["domains"])
config["blacklist"]["domains"] = [
d for d in config["blacklist"]["domains"] if d != domain
]
config["blacklist"]["domains"] = [d for d in config["blacklist"]["domains"] if d != domain]
if len(config["blacklist"]["domains"]) < original_len:
save_config(config)

29
src/web/deps.py Normal file
View file

@ -0,0 +1,29 @@
"""Shared dependencies and state for web application."""
import os
import threading
from pathlib import Path
# Project root directory: three levels above this module's file.
PROJECT_DIR = Path(__file__).parent.parent.parent
# Downloads directory for serving media files.
# RMC_DOWNLOAD_DIR overrides the default "<project root>/downloads".
# The name is bound to a str first and then rebound to a Path on the next line.
DOWNLOADS_DIR = os.environ.get("RMC_DOWNLOAD_DIR", str(PROJECT_DIR / "downloads"))
DOWNLOADS_DIR = Path(DOWNLOADS_DIR)
# Thumbnails directory, kept inside the downloads directory.
# NOTE: it is created as an import-time side effect of this module.
THUMBS_DIR = DOWNLOADS_DIR / ".thumbs"
THUMBS_DIR.mkdir(parents=True, exist_ok=True)
# Timezone name from the RMC_TIMEZONE env var, defaulting to "UTC".
TIMEZONE = os.environ.get("RMC_TIMEZONE", "UTC")
# Scheduler config file path (YAML file in the project root).
SCHEDULER_CONFIG_PATH = PROJECT_DIR / "scheduler_config.yaml"
# Scheduler DB path (file in the project root).
SCHEDULER_DB_PATH = PROJECT_DIR / "scheduler.db"
# Collector state shared across the web app; callers are expected to hold
# collector_lock while reading or updating collector_status.
collector_lock = threading.Lock()
collector_status: dict = {"running": False, "last_run": None, "last_result": None}

View file

@ -0,0 +1 @@
"""FastAPI router modules."""

246
src/web/routers/config.py Normal file
View file

@ -0,0 +1,246 @@
"""Configuration and blacklist API routes."""
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from .. import config_manager
router = APIRouter(tags=["config"])
class SubredditCreate(BaseModel):
    """Request body for adding a subreddit target."""

    # Subreddit name; passed as the first argument to config_manager.add_subreddit.
    name: str
    # Per-target limit forwarded to config_manager.add_subreddit.
    limit: int = 100
    # Sort value forwarded to config_manager.add_subreddit.
    sort: str = "new"
class UserCreate(BaseModel):
    """Request body for adding a user target."""

    # User name; passed as the first argument to config_manager.add_user.
    name: str
    # Per-target limit forwarded to config_manager.add_user.
    limit: int = 100
class BlacklistItem(BaseModel):
    """Request body carrying a single blacklist entry."""

    # The entry itself: an author, subreddit, keyword or domain, depending on
    # which blacklist route receives the body.
    value: str
@router.get("/api/config")
async def get_config():
    """Return the complete configuration as loaded by the config manager."""
    full_config = config_manager.load_config()
    return full_config
@router.get("/api/subreddits")
async def list_subreddits():
    """Return the subreddit targets reported by the config manager."""
    subreddits = config_manager.get_subreddits()
    return subreddits
@router.post("/api/subreddits")
async def add_subreddit(data: SubredditCreate):
    """Add a new subreddit target.

    The name is stripped of surrounding whitespace before it is validated and
    stored, so blank or padded names never reach the configuration file.

    Raises:
        HTTPException: 400 when the name is empty or whitespace-only,
            409 when the config manager reports the subreddit already exists.
    """
    name = data.name.strip()
    if not name:
        raise HTTPException(status_code=400, detail="Name is required")

    if not config_manager.add_subreddit(name, data.limit, data.sort):
        raise HTTPException(status_code=409, detail="Subreddit already exists")

    return {"message": f"Subreddit '{name}' added successfully"}
@router.delete("/api/subreddits/{name}")
async def delete_subreddit(name: str):
    """Remove a subreddit target; respond 404 when nothing was removed."""
    if config_manager.remove_subreddit(name):
        return {"message": f"Subreddit '{name}' removed successfully"}
    raise HTTPException(status_code=404, detail="Subreddit not found")
@router.get("/api/users")
async def list_users():
    """Return the user targets reported by the config manager."""
    users = config_manager.get_users()
    return users
@router.post("/api/users")
async def add_user(data: UserCreate):
    """Add a new user target.

    The name is stripped of surrounding whitespace before it is validated and
    stored, so blank or padded names never reach the configuration file.

    Raises:
        HTTPException: 400 when the name is empty or whitespace-only,
            409 when the config manager reports the user already exists.
    """
    name = data.name.strip()
    if not name:
        raise HTTPException(status_code=400, detail="Name is required")

    if not config_manager.add_user(name, data.limit):
        raise HTTPException(status_code=409, detail="User already exists")

    return {"message": f"User '{name}' added successfully"}
@router.delete("/api/users/{name}")
async def delete_user(name: str):
    """Remove a user target; respond 404 when nothing was removed."""
    if config_manager.remove_user(name):
        return {"message": f"User '{name}' removed successfully"}
    raise HTTPException(status_code=404, detail="User not found")
# Blacklist endpoints
@router.get("/api/blacklist")
async def get_blacklist():
    """Return the blacklist section as provided by the config manager."""
    blacklist = config_manager.get_blacklist()
    return blacklist
@router.post("/api/blacklist/authors")
async def add_blacklist_author(data: BlacklistItem):
    """Blacklist an author.

    Responds 400 for an empty value and 409 when the config manager reports
    the author is already blacklisted.
    """
    author = data.value
    if not author:
        raise HTTPException(status_code=400, detail="Value is required")

    if config_manager.add_blacklist_author(author):
        return {"message": f"Author '{author}' added to blacklist"}
    raise HTTPException(status_code=409, detail="Author already blacklisted")
@router.delete("/api/blacklist/authors/{author}")
async def remove_blacklist_author(author: str):
    """Drop an author from the blacklist; respond 404 when it was not listed."""
    if config_manager.remove_blacklist_author(author):
        return {"message": f"Author '{author}' removed from blacklist"}
    raise HTTPException(status_code=404, detail="Author not found in blacklist")
@router.post("/api/blacklist/subreddits")
async def add_blacklist_subreddit(data: BlacklistItem):
    """Blacklist a subreddit.

    Responds 400 for an empty value and 409 when the config manager reports
    the subreddit is already blacklisted.
    """
    subreddit = data.value
    if not subreddit:
        raise HTTPException(status_code=400, detail="Value is required")

    if config_manager.add_blacklist_subreddit(subreddit):
        return {"message": f"Subreddit '{subreddit}' added to blacklist"}
    raise HTTPException(status_code=409, detail="Subreddit already blacklisted")
@router.delete("/api/blacklist/subreddits/{subreddit}")
async def remove_blacklist_subreddit(subreddit: str):
    """Drop a subreddit from the blacklist; respond 404 when it was not listed."""
    if config_manager.remove_blacklist_subreddit(subreddit):
        return {"message": f"Subreddit '{subreddit}' removed from blacklist"}
    raise HTTPException(status_code=404, detail="Subreddit not found in blacklist")
@router.post("/api/blacklist/keywords")
async def add_blacklist_keyword(data: BlacklistItem):
    """Blacklist a title keyword.

    Responds 400 for an empty value and 409 when the config manager reports
    the keyword is already blacklisted.
    """
    keyword = data.value
    if not keyword:
        raise HTTPException(status_code=400, detail="Value is required")

    if config_manager.add_blacklist_keyword(keyword):
        return {"message": f"Keyword '{keyword}' added to blacklist"}
    raise HTTPException(status_code=409, detail="Keyword already blacklisted")
@router.delete("/api/blacklist/keywords/{keyword:path}")
async def remove_blacklist_keyword(keyword: str):
    """Drop a title keyword from the blacklist; respond 404 when it was not listed.

    The ``:path`` converter lets the keyword itself contain slashes.
    """
    if config_manager.remove_blacklist_keyword(keyword):
        return {"message": f"Keyword '{keyword}' removed from blacklist"}
    raise HTTPException(status_code=404, detail="Keyword not found in blacklist")
@router.post("/api/blacklist/domains")
async def add_blacklist_domain(data: BlacklistItem):
    """Blacklist a domain.

    Responds 400 for an empty value and 409 when the config manager reports
    the domain is already blacklisted.
    """
    domain = data.value
    if not domain:
        raise HTTPException(status_code=400, detail="Value is required")

    if config_manager.add_blacklist_domain(domain):
        return {"message": f"Domain '{domain}' added to blacklist"}
    raise HTTPException(status_code=409, detail="Domain already blacklisted")
@router.delete("/api/blacklist/domains/{domain:path}")
async def remove_blacklist_domain(domain: str):
    """Drop a domain from the blacklist; respond 404 when it was not listed.

    The ``:path`` converter lets the domain value contain slashes.
    """
    if config_manager.remove_blacklist_domain(domain):
        return {"message": f"Domain '{domain}' removed from blacklist"}
    raise HTTPException(status_code=404, detail="Domain not found in blacklist")
# Settings endpoints
class DownloadSettings(BaseModel):
    """Request body for updating the ``download`` section of the configuration.

    Every field is copied verbatim into ``config["download"]`` under the same
    key by the update_download_settings route.
    """

    # Media type names allowed for download.
    media_types: list[str] = ["image"]
    # Stored as config["download"]["min_score"].
    min_score: int = 1
    # Stored as config["download"]["skip_nsfw"].
    skip_nsfw: bool = False
    # Stored as config["download"]["max_file_size_mb"].
    max_file_size_mb: int = 200
    # Stored as config["download"]["videos_only_from_favorites"].
    videos_only_from_favorites: bool = False
@router.get("/api/settings")
async def get_settings():
    """Return the download, rate-limit and blacklist sections of the config.

    A section missing from the configuration is reported as an empty dict.
    """
    current = config_manager.load_config()
    sections = ("download", "rate_limit", "blacklist")
    return {section: current.get(section, {}) for section in sections}
@router.put("/api/settings/download")
async def update_download_settings(settings: DownloadSettings):
    """Persist the submitted download settings and echo them back."""
    config = config_manager.load_config()
    # Create the section on first use, then copy each field under its own name.
    download_section = config.setdefault("download", {})
    for field in (
        "media_types",
        "min_score",
        "skip_nsfw",
        "max_file_size_mb",
        "videos_only_from_favorites",
    ):
        download_section[field] = getattr(settings, field)

    config_manager.save_config(config)
    return {"message": "Download settings updated", "settings": settings.model_dump()}
class RateLimitSettings(BaseModel):
    """Request body for updating the ``rate_limit`` section of the configuration.

    Both fields are copied verbatim into ``config["rate_limit"]`` under the
    same key by the update_rate_limit_settings route.
    """

    # Stored as config["rate_limit"]["requests_per_minute"].
    requests_per_minute: int = 20
    # Stored as config["rate_limit"]["download_delay_seconds"].
    download_delay_seconds: float = 2.0
@router.put("/api/settings/rate-limit")
async def update_rate_limit_settings(settings: RateLimitSettings):
    """Persist the submitted rate-limit settings and echo them back."""
    config = config_manager.load_config()
    # Create the section on first use, then copy each field under its own name.
    rate_limit_section = config.setdefault("rate_limit", {})
    for field in ("requests_per_minute", "download_delay_seconds"):
        rate_limit_section[field] = getattr(settings, field)

    config_manager.save_config(config)
    return {"message": "Rate limit settings updated", "settings": settings.model_dump()}

View file

@ -0,0 +1,115 @@
"""Favorites and authors API routes."""
from fastapi import APIRouter, HTTPException, Query
from ...database import Database
from .. import config_manager
router = APIRouter(tags=["favorites"])
@router.get("/api/favorites")
async def get_favorites(limit: int = Query(default=50, ge=0, le=200), offset: int = Query(default=0, ge=0)):
    """Return one page of favorited posts together with the total count.

    ``limit`` is bounded on both sides. Without the lower bound a negative
    value would pass validation; assuming the database layer forwards it to a
    SQLite ``LIMIT`` clause (not visible here), that would remove the row cap
    and bypass the ``le=200`` ceiling.

    Returns:
        A dict with ``favorites`` (each entry flagged ``is_favorite=True``),
        ``total``, and the ``limit``/``offset`` that were applied.
    """
    db = Database()
    favorites = db.get_favorites(limit, offset)
    total = db.count_favorites()

    # Everything returned here is a favorite by definition.
    for fav in favorites:
        fav["is_favorite"] = True

    return {"favorites": favorites, "total": total, "limit": limit, "offset": offset}
# Authors that must never be added to the user collection targets.
_NON_COLLECTABLE_AUTHORS = ("[deleted]", "AutoModerator")


# NOTE: the fixed-path routes are registered BEFORE the
# "/api/favorites/{post_id}" routes on purpose. Routes are matched in
# registration order, so declaring POST "/api/favorites/{post_id}" first would
# capture POST "/api/favorites/sync-users" with post_id="sync-users" and answer
# 404 "Post not found" instead of running the sync.
@router.get("/api/favorites/authors")
async def get_favorite_authors():
    """Get list of unique authors from favorited posts."""
    db = Database()
    return db.get_favorite_authors()


@router.post("/api/favorites/sync-users")
async def sync_favorite_authors_to_users():
    """Add all authors from favorites to user collection targets.

    Placeholder authors ("[deleted]", "AutoModerator") are skipped, matching
    the rule add_favorite applies when it adds a single author.

    Returns:
        A dict with the number of users added (``synced``), their names
        (``added_users``) and a human-readable ``message``.
    """
    db = Database()
    authors = db.get_favorite_authors()

    added = []
    for author in authors:
        if not author or author in _NON_COLLECTABLE_AUTHORS:
            continue
        # add_user returns a falsy value when the user is already a target.
        if config_manager.add_user(author, limit=100):
            added.append(author)

    return {"synced": len(added), "added_users": added, "message": f"Added {len(added)} users to collection targets"}


@router.post("/api/favorites/{post_id}")
async def add_favorite(post_id: str, add_user_to_collection: bool = True):
    """Add a post to favorites. Optionally adds the author to user collection.

    Raises:
        HTTPException: 404 when the post is not in the database.

    Returns:
        A dict with ``message`` and ``added``; ``user_added`` is included when
        the post's author was newly added to the collection targets.
    """
    db = Database()
    post = db.get_post(post_id)
    if not post:
        raise HTTPException(status_code=404, detail="Post not found")

    added = db.add_favorite(post_id)
    result = {
        "message": f"Post '{post_id}' added to favorites" if added else "Post already in favorites",
        "added": added,
    }

    if add_user_to_collection and post.author and post.author not in _NON_COLLECTABLE_AUTHORS:
        user_added = config_manager.add_user(post.author, limit=100)
        if user_added:
            result["user_added"] = post.author
            result["message"] += f". User '{post.author}' added to collection targets."

    return result


@router.delete("/api/favorites/{post_id}")
async def remove_favorite(post_id: str):
    """Remove a post from favorites; respond 404 when it was not a favorite."""
    db = Database()
    removed = db.remove_favorite(post_id)
    if not removed:
        raise HTTPException(status_code=404, detail="Post not in favorites")
    return {"message": f"Post '{post_id}' removed from favorites"}
# Authors endpoints
@router.get("/api/authors")
async def get_authors(
    limit: int = Query(default=50, ge=0, le=200),
    offset: int = Query(default=0, ge=0),
    favorites_only: bool = Query(default=False),
    sort: str = Query(default="count"),
):
    """Get one page of authors with stats and thumbnails.

    ``limit`` is bounded on both sides. Without the lower bound a negative
    value would pass validation; assuming the database layer forwards it to a
    SQLite ``LIMIT`` clause (not visible here), that would remove the row cap
    and bypass the ``le=200`` ceiling.

    Returns:
        A dict with ``authors``, ``total`` and the applied ``limit``/``offset``.
    """
    db = Database()
    authors = db.get_authors_with_stats(limit, offset, favorites_only, sort)
    total = db.count_authors(favorites_only)
    return {"authors": authors, "total": total, "limit": limit, "offset": offset}
@router.get("/api/authors/{author}/media")
async def get_author_media(
    author: str,
    limit: int = Query(default=50, ge=0, le=200),
    offset: int = Query(default=0, ge=0),
    sort: str = Query(default="newest"),
):
    """Get one page of media from a specific author.

    ``limit`` is bounded on both sides. Without the lower bound a negative
    value would pass validation; assuming the database layer forwards it to a
    SQLite ``LIMIT`` clause (not visible here), that would remove the row cap
    and bypass the ``le=200`` ceiling.

    Returns:
        A dict with ``files`` (each flagged with ``is_favorite``), ``total``,
        the ``author`` and the applied ``limit``/``offset``.
    """
    db = Database()
    files = db.get_media_by_authors([author], limit, offset, sort=sort)
    total = db.count_media_by_authors([author])

    for f in files:
        f["is_favorite"] = db.is_favorite(f["id"])

    return {"files": files, "total": total, "author": author, "limit": limit, "offset": offset}

384
src/web/routers/media.py Normal file
View file

@ -0,0 +1,384 @@
"""Media browsing, serving, and cleanup API routes."""
import subprocess
from pathlib import Path
from fastapi import APIRouter, Header, HTTPException, Query
from fastapi.responses import FileResponse, StreamingResponse
from ...database import Database
from .. import config_manager
from ..deps import DOWNLOADS_DIR, THUMBS_DIR
router = APIRouter(tags=["media"])
@router.get("/api/media")
async def get_media_files(
    limit: int = Query(default=50, ge=0, le=200),
    offset: int = Query(default=0, ge=0),
    subreddit: str | None = None,
    media_type: str | None = None,
    author: str | None = None,
    sort: str = Query(default="newest"),
    favorites_only: bool = Query(default=False),
    favorite_authors: bool = Query(default=False),
):
    """Get media files with pagination, filtering and sorting.

    Three listing modes, checked in this order:

    * ``favorites_only``: favorited posts only.
    * ``favorite_authors``: media from authors of favorited posts, filtered by
      ``subreddit``/``media_type`` and ordered by ``sort``.
    * default: all media, filtered by ``subreddit``/``media_type``/``author``.

    Every returned file carries ``is_favorite`` and ``is_author_favorite`` in
    all three modes, so clients do not have to special-case the favorites view.

    ``limit`` has a lower bound so a negative value cannot bypass the
    ``le=200`` ceiling (assuming it reaches a SQLite ``LIMIT`` clause in the
    database layer, which is not visible here).
    """
    db = Database()

    fav_authors_list = db.get_favorite_authors()
    fav_authors_set = {a.lower() for a in fav_authors_list} if fav_authors_list else set()

    if favorites_only:
        files = db.get_favorites(limit, offset)
        total = db.count_favorites()
    elif favorite_authors:
        if fav_authors_list:
            # Pass the requested sort through; it was previously ignored here.
            files = db.get_media_by_authors(fav_authors_list, limit, offset, subreddit, media_type, sort=sort)
            total = db.count_media_by_authors(fav_authors_list, subreddit, media_type)
        else:
            files = []
            total = 0
    else:
        files = db.get_media_files(limit, offset, subreddit, media_type, sort, author)
        total = db.get_total_media_count(subreddit, media_type, author)

    for f in files:
        # Rows from the favorites listing are favorites by definition; the
        # other modes need a lookup per file.
        f["is_favorite"] = True if favorites_only else db.is_favorite(f["id"])
        file_author = f.get("author", "")
        f["is_author_favorite"] = bool(file_author) and file_author.lower() in fav_authors_set

    return {"files": files, "total": total, "limit": limit, "offset": offset}
@router.get("/api/media/subreddits")
async def get_media_subreddits():
    """Return every subreddit the database reports via get_all_subreddits()."""
    subreddits = Database().get_all_subreddits()
    return subreddits
@router.get("/api/media/authors")
async def get_media_authors():
    """Return every author the database reports via get_all_authors()."""
    authors = Database().get_all_authors()
    return authors
def _parse_byte_range(range_header: str, file_size: int) -> tuple[int, int] | None:
    """Parse a single ``bytes=`` Range header into inclusive (start, end) offsets.

    Supports ``bytes=A-B``, the open-ended ``bytes=A-`` and the suffix form
    ``bytes=-N`` (the last N bytes). ``end`` is clamped to the last byte of
    the file. ``start`` is NOT checked against ``file_size``; the caller
    decides how to answer an unsatisfiable range.

    Returns:
        ``(start, end)``, or ``None`` when the header is malformed, uses
        another unit, or lists several ranges. The header should then be
        ignored and the whole file served.
    """
    unit, _, spec = range_header.partition("=")
    if unit.strip().lower() != "bytes":
        return None

    first, dash, last = spec.strip().partition("-")
    if not dash:
        return None

    try:
        if not first:
            # Suffix form: "-N" means the final N bytes of the file.
            suffix_length = int(last)
            if suffix_length <= 0:
                return None
            return max(file_size - suffix_length, 0), file_size - 1
        start = int(first)
        end = int(last) if last else file_size - 1
    except ValueError:
        return None

    if start < 0 or end < start:
        return None
    return start, min(end, file_size - 1)


@router.get("/api/media/file/{filename:path}")
async def get_media_file(filename: str, range: str | None = Header(None)):
    """Serve a media file with Range request support for video streaming.

    Non-video files, and requests with no usable Range header, get the whole
    file. Video requests with a valid Range header get a 206 partial response
    streamed in 64 KiB chunks.

    Raises:
        HTTPException: 404 when the path escapes the downloads directory or is
            not a regular file; 416 when the requested range starts at or
            beyond the end of the file.
    """
    # ``filename`` is client-controlled and may contain "..", so resolve it and
    # refuse anything outside the downloads directory. Symlinks are resolved
    # too, so links pointing outside the directory are rejected as well.
    downloads_root = DOWNLOADS_DIR.resolve()
    file_path = (downloads_root / filename).resolve()
    if not file_path.is_relative_to(downloads_root) or not file_path.is_file():
        raise HTTPException(status_code=404, detail="File not found")

    file_size = file_path.stat().st_size
    suffix = file_path.suffix.lower()

    media_types = {
        ".mp4": "video/mp4",
        ".webm": "video/webm",
        ".mov": "video/quicktime",
        ".avi": "video/x-msvideo",
        ".mkv": "video/x-matroska",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".gif": "image/gif",
        ".webp": "image/webp",
    }
    media_type = media_types.get(suffix, "application/octet-stream")

    byte_range = None
    if range and suffix in (".mp4", ".webm", ".mov", ".avi", ".mkv"):
        byte_range = _parse_byte_range(range, file_size)

    if byte_range is None:
        return FileResponse(file_path, media_type=media_type, headers={"Accept-Ranges": "bytes"})

    start, end = byte_range
    if start >= file_size:
        # Also covers empty files, where no byte range can be satisfied.
        raise HTTPException(
            status_code=416,
            detail="Requested range not satisfiable",
            headers={"Content-Range": f"bytes */{file_size}"},
        )

    content_length = end - start + 1

    def iterfile():
        with open(file_path, "rb") as f:
            f.seek(start)
            remaining = content_length
            chunk_size = 64 * 1024
            while remaining > 0:
                data = f.read(min(chunk_size, remaining))
                if not data:
                    break
                remaining -= len(data)
                yield data

    headers = {
        "Content-Range": f"bytes {start}-{end}/{file_size}",
        "Accept-Ranges": "bytes",
        "Content-Length": str(content_length),
    }

    return StreamingResponse(iterfile(), status_code=206, media_type=media_type, headers=headers)
def _generate_thumbnail(video_path: Path) -> Path | None:
    """Generate a thumbnail for a video file using ffmpeg.

    The thumbnail is stored in THUMBS_DIR as ``<video file name>.jpg`` and is
    reused on later calls. A frame is first taken at one second and, if that
    produces nothing, at the very start of the video.

    Returns:
        The thumbnail path, or ``None`` when ffmpeg is missing, times out, or
        produces no image.
    """
    thumb_path = THUMBS_DIR / f"{video_path.name}.jpg"

    # Reuse a cached thumbnail only if it has content. A zero-byte file is
    # treated as the remnant of an interrupted run and is regenerated.
    if thumb_path.is_file() and thumb_path.stat().st_size > 0:
        return thumb_path

    try:
        for seek_time in ("00:00:01", "00:00:00"):
            result = subprocess.run(
                [
                    "ffmpeg",
                    "-y",
                    "-i",
                    str(video_path),
                    "-ss",
                    seek_time,
                    "-vframes",
                    "1",
                    "-vf",
                    "scale=320:320:force_original_aspect_ratio=decrease",
                    "-q:v",
                    "2",
                    str(thumb_path),
                ],
                capture_output=True,
                timeout=30,
            )
            if result.returncode == 0 and thumb_path.exists():
                return thumb_path
            # Remove whatever a failed attempt left behind before retrying.
            thumb_path.unlink(missing_ok=True)
    except (subprocess.TimeoutExpired, FileNotFoundError):
        # A timed-out ffmpeg may leave a partial file; never cache it.
        thumb_path.unlink(missing_ok=True)

    return None
@router.get("/api/media/thumb/{filename:path}")
async def get_video_thumbnail(filename: str):
    """Get or generate a thumbnail for a video file.

    Image files are returned as-is with their own content type. Videos get an
    ffmpeg-generated JPEG. If generation fails, the file's magic bytes are
    checked in case an image was saved with a video extension.

    Raises:
        HTTPException: 404 when the path escapes the downloads directory or is
            not a regular file; 400 for unsupported extensions; 500 when no
            thumbnail could be produced.
    """
    # ``filename`` is client-controlled and may contain "..", so resolve it and
    # refuse anything outside the downloads directory.
    downloads_root = DOWNLOADS_DIR.resolve()
    video_path = (downloads_root / filename).resolve()
    if not video_path.is_relative_to(downloads_root) or not video_path.is_file():
        raise HTTPException(status_code=404, detail="Video not found")

    video_extensions = {".mp4", ".webm", ".mov", ".avi", ".mkv"}
    image_media_types = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".gif": "image/gif",
        ".webp": "image/webp",
    }
    suffix = video_path.suffix.lower()

    if suffix in image_media_types:
        # Serve images with their real content type rather than always JPEG.
        return FileResponse(video_path, media_type=image_media_types[suffix])

    if suffix not in video_extensions:
        raise HTTPException(status_code=400, detail="Not a video file")

    thumb_path = _generate_thumbnail(video_path)
    if thumb_path and thumb_path.exists():
        return FileResponse(thumb_path, media_type="image/jpeg")

    # Check magic bytes as fallback: the "video" may really be an image.
    try:
        with open(video_path, "rb") as f:
            header = f.read(12)
    except OSError:
        # Unreadable file: fall through to the generic failure below.
        header = b""

    if header[:2] == b"\xff\xd8":
        return FileResponse(video_path, media_type="image/jpeg")
    if header[:8] == b"\x89PNG\r\n\x1a\n":
        return FileResponse(video_path, media_type="image/png")

    raise HTTPException(status_code=500, detail="Failed to generate thumbnail")
@router.delete("/api/media/{post_id}")
async def delete_media(post_id: str, blacklist_author: bool = False, blacklist_subreddit: bool = False):
    """Delete one post's media files and its database row.

    When requested, the post's author and/or subreddit are blacklisted first;
    the placeholder authors "[deleted]" and "AutoModerator" are never
    blacklisted. Responds 404 when the post is unknown.
    """
    db = Database()
    post = db.get_post(post_id)
    if not post:
        raise HTTPException(status_code=404, detail="Post not found")

    placeholder_authors = ("[deleted]", "AutoModerator")
    blacklisted = []

    author_is_blacklistable = post.author and post.author not in placeholder_authors
    if blacklist_author and author_is_blacklistable:
        config_manager.add_blacklist_author(post.author)
        blacklisted.append(f"author:{post.author}")

    if blacklist_subreddit and post.subreddit:
        config_manager.add_blacklist_subreddit(post.subreddit)
        blacklisted.append(f"subreddit:{post.subreddit}")

    # Files go first, then the row that referenced them.
    if post.local_path:
        _delete_media_files(Path(post.local_path))

    with db._get_connection() as conn:
        conn.execute("DELETE FROM posts WHERE id = ?", (post_id,))
        conn.commit()

    response = {"message": f"Media '{post_id}' deleted successfully"}
    if blacklisted:
        response["blacklisted"] = blacklisted
    return response
@router.get("/api/media/{post_id}/info")
async def get_media_info(post_id: str):
    """Return the fields shown in the delete confirmation dialog.

    Responds 404 when the post is unknown.
    """
    post = Database().get_post(post_id)
    if not post:
        raise HTTPException(status_code=404, detail="Post not found")

    exposed_fields = ("id", "author", "subreddit", "title", "media_type")
    return {field: getattr(post, field) for field in exposed_fields}
# Cleanup endpoints
@router.get("/api/media/blacklist-preview")
async def preview_blacklist_cleanup():
    """Preview how many files would be deleted by the blacklist cleanup.

    ``author_count`` and ``subreddit_count`` are the per-list counts. A post
    by a blacklisted author in a blacklisted subreddit appears in both, so
    ``total_count`` counts unique post ids, which is the set the cleanup
    endpoint actually deletes.
    """
    blacklist = config_manager.get_blacklist()
    authors = blacklist.get("authors", [])
    subreddits = blacklist.get("subreddits", [])

    db = Database()
    author_count = db.count_posts_by_authors(authors) if authors else 0
    subreddit_count = db.count_posts_by_subreddits(subreddits) if subreddits else 0

    if authors and subreddits:
        # Only this case can overlap; de-duplicate by post id.
        matching_ids = {post.id for post in db.get_posts_by_authors(authors)}
        matching_ids.update(post.id for post in db.get_posts_by_subreddits(subreddits))
        total_count = len(matching_ids)
    else:
        total_count = author_count + subreddit_count

    return {
        "author_count": author_count,
        "subreddit_count": subreddit_count,
        "total_count": total_count,
        "authors": authors,
        "subreddits": subreddits,
    }
@router.post("/api/media/cleanup-blacklist")
async def cleanup_blacklisted_media():
    """Delete every post that matches a blacklisted author or subreddit.

    Each post is handled independently: a failure is recorded in ``errors``
    and the loop moves on to the next post.
    """
    blacklist = config_manager.get_blacklist()
    authors = blacklist.get("authors", [])
    subreddits = blacklist.get("subreddits", [])

    if not (authors or subreddits):
        return {"deleted": 0, "message": "No authors or subreddits in blacklist"}

    db = Database()
    candidates = []
    if authors:
        candidates.extend(db.get_posts_by_authors(authors))
    if subreddits:
        candidates.extend(db.get_posts_by_subreddits(subreddits))

    # A post can match both lists; keep the first occurrence of each id.
    unique_by_id = {}
    for candidate in candidates:
        unique_by_id.setdefault(candidate.id, candidate)

    deleted_count = 0
    errors = []
    for post in unique_by_id.values():
        try:
            if post.local_path:
                _delete_media_files(Path(post.local_path))
            with db._get_connection() as conn:
                conn.execute("DELETE FROM posts WHERE id = ?", (post.id,))
                conn.commit()
            deleted_count += 1
        except Exception as e:
            errors.append(f"{post.id}: {e!s}")

    return {
        "deleted": deleted_count,
        "errors": errors or None,
        "message": f"Deleted {deleted_count} files from blacklisted authors/subreddits",
    }
@router.get("/api/media/cleanup-preview")
async def preview_media_cleanup(media_type: str = Query(..., description="video or gif")):
    """Report how many stored files of the given type a cleanup would remove.

    Only "video" and "gif" are accepted; anything else yields a 400.
    """
    allowed_types = ("video", "gif")
    if media_type not in allowed_types:
        raise HTTPException(status_code=400, detail="media_type must be 'video' or 'gif'")

    count = Database().get_total_media_count(media_type=media_type)
    return {"media_type": media_type, "count": count}
@router.post("/api/media/cleanup-by-type")
async def cleanup_media_by_type(media_type: str = Query(..., description="video or gif")):
    """Delete every stored post of one media type ("video" or "gif").

    Each post is handled independently: a failure is recorded in ``errors``
    and the loop moves on to the next post. Any other media type yields a 400.
    """
    allowed_types = ("video", "gif")
    if media_type not in allowed_types:
        raise HTTPException(status_code=400, detail="media_type must be 'video' or 'gif'")

    db = Database()
    with db._get_connection() as conn:
        rows = conn.execute(
            "SELECT id, local_path FROM posts WHERE media_type = ? AND local_path IS NOT NULL",
            (media_type,),
        ).fetchall()

    deleted_count = 0
    errors = []
    for post_id, local_path in rows:
        try:
            if local_path:
                _delete_media_files(Path(local_path))
            with db._get_connection() as conn:
                conn.execute("DELETE FROM posts WHERE id = ?", (post_id,))
                conn.commit()
            deleted_count += 1
        except Exception as e:
            errors.append(f"{post_id}: {e!s}")

    return {
        "deleted": deleted_count,
        "media_type": media_type,
        "errors": errors or None,
        "message": f"Deleted {deleted_count} {media_type} files",
    }
def _delete_media_files(file_path: Path) -> None:
    """Delete a media file together with its JSON sidecar and its thumbnail.

    The sidecar is ``<file name>.json`` next to the media file and the
    thumbnail is ``<file name>.jpg`` inside ``THUMBS_DIR``. Files that are
    already absent are skipped silently.

    Args:
        file_path: Path of the media file to remove.

    Raises:
        OSError: If an existing file cannot be removed (e.g. permissions).
    """
    # unlink(missing_ok=True) replaces the exists()/unlink() pair so a file
    # that disappears between the check and the delete no longer raises.
    file_path.unlink(missing_ok=True)
    sidecar_path = file_path.with_suffix(file_path.suffix + ".json")
    sidecar_path.unlink(missing_ok=True)
    thumb_path = THUMBS_DIR / f"{file_path.name}.jpg"
    thumb_path.unlink(missing_ok=True)

View file

@ -0,0 +1,392 @@
"""Scheduler and collector API routes."""
import os
import re
import subprocess
from contextlib import suppress
from datetime import datetime
import yaml
from apscheduler.jobstores.sqlalchemy import SQLAlchemyJobStore
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from apscheduler.triggers.interval import IntervalTrigger
from fastapi import APIRouter, BackgroundTasks, HTTPException, Query
from pydantic import BaseModel
from ...database import Database
from .. import config_manager
from ..deps import (
PROJECT_DIR,
SCHEDULER_CONFIG_PATH,
SCHEDULER_DB_PATH,
TIMEZONE,
collector_lock,
collector_status,
)
# Router that carries the scheduler/collector endpoints defined in this module.
router = APIRouter(tags=["scheduler"])
# Scheduler setup
# Jobs are stored through SQLAlchemyJobStore in the SQLite file at SCHEDULER_DB_PATH.
jobstores = {"default": SQLAlchemyJobStore(url=f"sqlite:///{SCHEDULER_DB_PATH}")}
scheduler = BackgroundScheduler(jobstores=jobstores, timezone=TIMEZONE)
# Scheduler state
# In-memory settings shared by the functions below. load_scheduler_config()
# and the config endpoint overwrite these keys and also set "specific_times",
# which is why readers use .get("specific_times", <default>) for that key.
scheduler_config: dict = {
    "enabled": False,
    "interval_hours": 6,
    "mode": "interval",
}
def load_scheduler_config():
    """Refresh ``scheduler_config`` from the YAML file, if that file exists.

    Keys missing from the file fall back to the built-in defaults; when the
    file is absent the in-memory settings are left untouched.
    """
    if not SCHEDULER_CONFIG_PATH.exists():
        return
    with open(SCHEDULER_CONFIG_PATH) as f:
        loaded = yaml.safe_load(f) or {}
    fallbacks = (
        ("enabled", False),
        ("interval_hours", 6),
        ("mode", "interval"),
        ("specific_times", ["00:00", "06:00", "12:00", "18:00"]),
    )
    for key, fallback in fallbacks:
        scheduler_config[key] = loaded.get(key, fallback)
def _save_scheduler_config():
    """Write the current ``scheduler_config`` settings to the YAML file."""
    with open(SCHEDULER_CONFIG_PATH, "w") as f:
        snapshot = {key: scheduler_config[key] for key in ("enabled", "interval_hours", "mode")}
        # "specific_times" may not have been set yet, so fall back to the default slots.
        snapshot["specific_times"] = scheduler_config.get(
            "specific_times", ["00:00", "06:00", "12:00", "18:00"]
        )
        yaml.dump(snapshot, f)
def setup_scheduler_job():
    """Register (or remove) the collector job according to ``scheduler_config``.

    Any existing "collector_job" is removed first. Nothing is scheduled when
    the scheduler is disabled or when specific-times mode has no valid entry.

    In "interval" mode the job runs every ``interval_hours`` hours. In any
    other mode the "HH:MM" entries of ``specific_times`` are honoured
    including their minutes; entries that are malformed or out of range are
    ignored.
    """
    # Removing a job that does not exist raises; that case is expected here.
    with suppress(Exception):
        scheduler.remove_job("collector_job")
    if not scheduler_config["enabled"]:
        return

    if scheduler_config["mode"] == "interval":
        trigger = IntervalTrigger(hours=scheduler_config["interval_hours"])
    else:
        slots = _parse_specific_times(
            scheduler_config.get("specific_times", ["00:00", "06:00", "12:00", "18:00"])
        )
        if not slots:
            return
        # One cron trigger per distinct minute value. When every entry is on
        # the hour this yields a single trigger with minute=0 and the sorted,
        # comma-joined hours.
        hours_by_minute: dict = {}
        for hour, minute in slots:
            hours_by_minute.setdefault(minute, []).append(hour)
        cron_triggers = [
            CronTrigger(hour=",".join(str(h) for h in hours), minute=minute)
            for minute, hours in hours_by_minute.items()
        ]
        if len(cron_triggers) == 1:
            trigger = cron_triggers[0]
        else:
            # Imported here because it is only needed for mixed-minute schedules.
            from apscheduler.triggers.combining import OrTrigger

            trigger = OrTrigger(cron_triggers)

    scheduler.add_job(
        _run_collector_scheduled,
        trigger=trigger,
        id="collector_job",
        name="Reddit Collector",
        replace_existing=True,
    )


def _parse_specific_times(times):
    """Return the sorted, de-duplicated (hour, minute) pairs found in "HH:MM" entries."""
    slots = set()
    for entry in times or []:
        try:
            hour_text, minute_text = str(entry).split(":")
            hour, minute = int(hour_text), int(minute_text)
        except ValueError:
            # Wrong number of fields or non-numeric parts: skip the entry.
            continue
        if 0 <= hour <= 23 and 0 <= minute <= 59:
            slots.add((hour, minute))
    return sorted(slots)
def _run_collector_scheduled():
    """Run the collector as a subprocess and record the run in scheduler history.

    Returns immediately when a collector run is already in progress.
    Otherwise marks the collector as running, creates a history row, runs
    ``python3 -m src.main`` in ``PROJECT_DIR`` (4 hour timeout), parses the
    processed/downloaded counters from its stdout and stores the outcome in
    both ``collector_status["last_result"]`` and the history row. The
    "running" flag is always cleared at the end, even when the database or
    the subprocess fails.
    """
    with collector_lock:
        if collector_status["running"]:
            return
        collector_status["running"] = True
        collector_status["last_run"] = datetime.now().isoformat()

    db = None
    run_id = None
    posts_processed = 0
    posts_downloaded = 0

    def _finish(status, *error):
        # No history row exists if the database failed before one was created.
        if run_id is not None:
            db.finish_scheduler_run(run_id, status, posts_processed, posts_downloaded, *error)

    try:
        # Opened inside the try so a database failure cannot leave "running" stuck at True.
        db = Database()
        run_id = db.add_scheduler_run(datetime.now())
        result = subprocess.run(
            ["python3", "-m", "src.main"],
            cwd=PROJECT_DIR,
            capture_output=True,
            text=True,
            timeout=14400,
        )
        if result.stdout:
            # The counters are scraped from the collector's summary output.
            processed_match = re.search(r"Posts processed:\s*(\d+)", result.stdout)
            downloaded_match = re.search(r"New downloads:\s*(\d+)", result.stdout)
            if processed_match:
                posts_processed = int(processed_match.group(1))
            if downloaded_match:
                posts_downloaded = int(downloaded_match.group(1))
        if result.returncode == 0:
            with collector_lock:
                collector_status["last_result"] = "success"
            _finish("success")
        else:
            with collector_lock:
                collector_status["last_result"] = "error"
            # Only the first 500 characters of stderr are stored.
            _finish("error", result.stderr[:500] if result.stderr else None)
    except subprocess.TimeoutExpired:
        with collector_lock:
            collector_status["last_result"] = "timeout"
        _finish("timeout", "Execution timed out after 4 hours")
    except Exception as e:
        with collector_lock:
            collector_status["last_result"] = f"error: {e!s}"
        _finish("error", str(e)[:500])
    finally:
        with collector_lock:
            collector_status["running"] = False
def _run_collector():
    """Run the collector subprocess in the background without a history record.

    Returns immediately when another run is already in progress. Otherwise
    runs ``python3 -m src.main`` in ``PROJECT_DIR`` (4 hour timeout) and
    stores "success", "error", "timeout" or "error: <message>" in
    ``collector_status["last_result"]``. The "running" flag is cleared when
    the run ends, whatever the outcome.
    """
    with collector_lock:
        # Check-and-set under the lock, like the other runners in this module:
        # the endpoint's own check happens outside the lock, so two requests
        # arriving together could otherwise both start a collector.
        if collector_status["running"]:
            return
        collector_status["running"] = True
        collector_status["last_run"] = datetime.now().isoformat()
    try:
        result = subprocess.run(
            ["python3", "-m", "src.main"],
            cwd=PROJECT_DIR,
            capture_output=True,
            text=True,
            timeout=14400,
        )
        with collector_lock:
            collector_status["last_result"] = "success" if result.returncode == 0 else "error"
    except subprocess.TimeoutExpired:
        with collector_lock:
            collector_status["last_result"] = "timeout"
    except Exception as e:
        with collector_lock:
            collector_status["last_result"] = f"error: {e!s}"
    finally:
        with collector_lock:
            collector_status["running"] = False
def _run_individual_collection(target_type: str, target_name: str, media_types: list[str], limit: int):
    """Run collection for a single user or subreddit.

    Builds a temporary copy of ``config.yaml`` whose targets contain only the
    requested user or subreddit, runs ``python3 -m src.main -c <temp config>``
    (2 hour timeout) and records the outcome in ``collector_status`` and in
    the scheduler history. Returns immediately when a run is already in
    progress.

    Args:
        target_type: "user" selects a user target; any other value is treated
            as a subreddit.
        target_name: Name of the user or subreddit.
        media_types: Media types written to ``download.media_types``.
        limit: Post limit for the single target.
    """
    import tempfile

    with collector_lock:
        if collector_status["running"]:
            return
        collector_status["running"] = True
        collector_status["last_run"] = datetime.now().isoformat()

    db = None
    run_id = None
    posts_processed = 0
    posts_downloaded = 0
    tmp_config_path = None

    def _finish(status, *error):
        # No history row exists if the database failed before one was created.
        if run_id is not None:
            db.finish_scheduler_run(run_id, status, posts_processed, posts_downloaded, *error)

    try:
        # Opened inside the try so a database failure cannot leave "running" stuck at True.
        db = Database()
        run_id = db.add_scheduler_run(datetime.now())

        with open(PROJECT_DIR / "config.yaml") as f:
            config = yaml.safe_load(f) or {}
        if target_type == "user":
            config["targets"] = {"subreddits": [], "users": [{"name": target_name, "limit": limit}]}
        else:
            config["targets"] = {"subreddits": [{"name": target_name, "limit": limit, "sort": "new"}], "users": []}
        # The base config may omit the download section or leave it empty.
        download_section = config.get("download") or {}
        download_section["media_types"] = media_types
        config["download"] = download_section

        with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as tmp:
            # Remember the path before writing so the file is cleaned up even if the dump fails.
            tmp_config_path = tmp.name
            yaml.dump(config, tmp)

        result = subprocess.run(
            ["python3", "-m", "src.main", "-c", tmp_config_path],
            cwd=PROJECT_DIR,
            capture_output=True,
            text=True,
            timeout=7200,
        )
        if result.stdout:
            processed_match = re.search(r"Posts processed:\s*(\d+)", result.stdout)
            downloaded_match = re.search(r"New downloads:\s*(\d+)", result.stdout)
            if processed_match:
                posts_processed = int(processed_match.group(1))
            if downloaded_match:
                posts_downloaded = int(downloaded_match.group(1))
        if result.returncode == 0:
            with collector_lock:
                collector_status["last_result"] = "success"
            _finish("success")
        else:
            with collector_lock:
                collector_status["last_result"] = "error"
            _finish("error", result.stderr[:500] if result.stderr else None)
    except subprocess.TimeoutExpired:
        with collector_lock:
            collector_status["last_result"] = "timeout"
        _finish("timeout", "Execution timed out")
    except Exception as e:
        with collector_lock:
            collector_status["last_result"] = f"error: {e!s}"
        _finish("error", str(e)[:500])
    finally:
        if tmp_config_path is not None:
            # Best-effort cleanup: a failed delete must not overwrite the recorded result.
            with suppress(OSError):
                os.unlink(tmp_config_path)
        with collector_lock:
            collector_status["running"] = False
# API Endpoints
@router.get("/api/collector/status")
async def get_collector_status():
    """Get collector status."""
    # Returns the shared status dict itself (the object the runner functions update).
    return collector_status
@router.post("/api/collector/run")
async def trigger_collector(background_tasks: BackgroundTasks):
    """Trigger the collector to run."""
    already_running = collector_status["running"]
    if already_running:
        # 409 Conflict: only one collector run at a time.
        raise HTTPException(status_code=409, detail="Collector is already running")
    # The run itself happens after the response, as a background task.
    background_tasks.add_task(_run_collector)
    return {"message": "Collector started"}
# Request body for PUT /api/scheduler/config; the fields mirror the keys of scheduler_config.
class SchedulerConfigUpdate(BaseModel):
    # Whether the collector job should be scheduled at all.
    enabled: bool
    # Hours between runs; used when mode is "interval".
    interval_hours: int = 6
    # "interval" selects the interval trigger; any other value uses specific_times.
    mode: str = "interval"
    # "HH:MM" entries used when mode is not "interval".
    specific_times: list[str] = ["00:00", "06:00", "12:00", "18:00"]
@router.get("/api/scheduler/status")
async def get_scheduler_status():
    """Get scheduler status including next run time."""
    job = scheduler.get_job("collector_job")
    # next_run stays None when no job is registered or the job has no next run time.
    next_run = job.next_run_time.isoformat() if job and job.next_run_time else None
    last_run = Database().get_last_scheduler_run()
    return {
        "enabled": scheduler_config["enabled"],
        "interval_hours": scheduler_config["interval_hours"],
        "mode": scheduler_config["mode"],
        "specific_times": scheduler_config.get("specific_times", ["00:00", "06:00", "12:00", "18:00"]),
        "next_run": next_run,
        "is_running": collector_status["running"],
        "last_run": last_run,
    }
@router.put("/api/scheduler/config")
async def update_scheduler_config(config: SchedulerConfigUpdate):
    """Update scheduler configuration."""
    # Reject a non-positive interval before anything is stored or persisted.
    # NOTE(review): APScheduler appears to coerce a zero-length interval to
    # one second, which would launch the collector continuously - confirm
    # against the installed version.
    if config.mode == "interval" and config.interval_hours < 1:
        raise HTTPException(status_code=400, detail="interval_hours must be at least 1")
    scheduler_config["enabled"] = config.enabled
    scheduler_config["interval_hours"] = config.interval_hours
    scheduler_config["mode"] = config.mode
    scheduler_config["specific_times"] = config.specific_times
    # Persist first, then rebuild the job from the new settings.
    _save_scheduler_config()
    setup_scheduler_job()
    return {"message": "Scheduler configuration updated", "config": scheduler_config}
@router.get("/api/scheduler/history")
async def get_scheduler_history(limit: int = Query(default=20, le=100)):
    """Get scheduler run history."""
    # limit is capped at 100 by the Query validator above.
    return Database().get_scheduler_history(limit)
@router.post("/api/scheduler/run-now")
async def run_scheduler_now(background_tasks: BackgroundTasks):
    """Trigger an immediate scheduler run."""
    busy = collector_status["running"]
    if busy:
        raise HTTPException(status_code=409, detail="Collector is already running")
    # Uses the scheduled runner, so this run is written to the scheduler history.
    background_tasks.add_task(_run_collector_scheduled)
    return {"message": "Collector started (scheduled run)"}
# Request body for POST /api/collect/individual.
class IndividualCollectRequest(BaseModel):
    # "user" or "subreddit" (validated by the endpoint).
    target_type: str
    # Name of the user or subreddit to collect from.
    target_name: str
    # Media types to download; the endpoint accepts "image", "video" and "gif".
    media_types: list[str] = ["image"]
    # Post limit passed to the collector for this target.
    limit: int = 100
@router.post("/api/collect/individual")
async def collect_individual(request: IndividualCollectRequest, background_tasks: BackgroundTasks):
    """Run collection for a single user or subreddit."""
    if collector_status["running"]:
        raise HTTPException(status_code=409, detail="Collector is already running")
    if request.target_type not in ["user", "subreddit"]:
        raise HTTPException(status_code=400, detail="target_type must be 'user' or 'subreddit'")
    if not request.target_name:
        raise HTTPException(status_code=400, detail="target_name is required")

    # Report the first media type that is not supported, if any.
    valid_types = ["image", "video", "gif"]
    first_invalid = next((mt for mt in request.media_types if mt not in valid_types), None)
    if first_invalid is not None:
        raise HTTPException(status_code=400, detail=f"Invalid media type: {first_invalid}")

    background_tasks.add_task(
        _run_individual_collection,
        request.target_type,
        request.target_name,
        request.media_types,
        request.limit,
    )
    return {
        "message": f"Collection started for {request.target_type} '{request.target_name}'",
        "media_types": request.media_types,
        "limit": request.limit,
    }
@router.get("/api/collect/targets")
async def get_collection_targets():
    """Get available targets for individual collection."""
    favorite_authors = Database().get_favorite_authors()
    # Configured targets are reduced to their names for the response.
    subreddit_names = [entry["name"] for entry in config_manager.get_subreddits()]
    user_names = [entry["name"] for entry in config_manager.get_users()]
    return {
        "favorite_authors": favorite_authors,
        "configured_subreddits": subreddit_names,
        "configured_users": user_names,
    }

56
src/web/routers/stats.py Normal file
View file

@ -0,0 +1,56 @@
"""Statistics API routes."""
import shutil
from fastapi import APIRouter, Query
from ...database import Database
from ..deps import DOWNLOADS_DIR
# Router that carries the statistics endpoints defined in this module.
router = APIRouter(tags=["stats"])
@router.get("/api/stats")
async def get_stats():
    """Get collection statistics."""
    gib = 1024 * 1024 * 1024
    stats = Database().get_stats()

    # Sizes of the regular files directly inside DOWNLOADS_DIR, excluding .json sidecars.
    sizes = []
    if DOWNLOADS_DIR.exists():
        sizes = [
            entry.stat().st_size
            for entry in DOWNLOADS_DIR.iterdir()
            if entry.is_file() and not entry.name.endswith(".json")
        ]
    total_size = sum(sizes)
    stats["disk_size_bytes"] = total_size
    stats["disk_size_mb"] = round(total_size / (1024 * 1024), 2)
    stats["disk_size_gb"] = round(total_size / gib, 2)
    stats["file_count"] = len(sizes)

    try:
        usage = shutil.disk_usage(DOWNLOADS_DIR)
    except OSError:
        # Volume figures are reported as zero when they cannot be read.
        free_gb, total_gb, used_percent = 0, 0, 0
    else:
        free_gb = round(usage.free / gib, 2)
        total_gb = round(usage.total / gib, 2)
        used_percent = round((usage.used / usage.total) * 100, 1)
    stats["disk_free_gb"] = free_gb
    stats["disk_total_gb"] = total_gb
    stats["disk_used_percent"] = used_percent
    return stats
@router.get("/api/stats/enhanced")
async def get_enhanced_stats():
    """Get enhanced statistics for dashboard."""
    # Thin wrapper: the database layer produces the whole payload.
    return Database().get_enhanced_stats()
@router.get("/api/stats/recent")
async def get_recent_downloads(limit: int = Query(default=10, le=50)):
    """Get recent downloads."""
    # limit is capped at 50 by the Query validator above.
    return Database().get_recent_downloads(limit)

0
tests/__init__.py Normal file
View file

96
tests/conftest.py Normal file
View file

@ -0,0 +1,96 @@
"""Shared test fixtures."""
import pytest
from src.config import (
BlacklistConfig,
Config,
DownloadConfig,
LoggingConfig,
RateLimitConfig,
SubredditTarget,
TargetsConfig,
UserTarget,
)
from src.database import Database
@pytest.fixture
def tmp_dir(tmp_path):
    """Provide a temporary directory."""
    # Alias for pytest's built-in tmp_path fixture.
    return tmp_path
@pytest.fixture
def db(tmp_path):
    """Return a Database backed by a new file inside the per-test temp directory."""
    return Database(db_path=str(tmp_path / "test_media.db"))
@pytest.fixture
def sample_config():
    """Return a fully populated Config built in memory (no file involved)."""
    targets = TargetsConfig(
        subreddits=[SubredditTarget(name="pics", limit=25, sort="hot")],
        users=[UserTarget(name="testuser", limit=10)],
    )
    download = DownloadConfig(
        output_dir="./test_downloads",
        media_types=["image", "video", "gif"],
        min_score=10,
        skip_nsfw=True,
        max_file_size_mb=100,
    )
    rate_limit = RateLimitConfig(requests_per_minute=10, download_delay_seconds=0.1)
    logging_config = LoggingConfig(level="DEBUG", file=None)
    # One entry per blacklist category so each filter has something to match.
    blacklist = BlacklistConfig(
        authors=["spammer"],
        subreddits=["spam_sub"],
        title_keywords=["buy now"],
        domains=["malware.com"],
    )
    return Config(
        targets=targets,
        download=download,
        rate_limit=rate_limit,
        logging=logging_config,
        blacklist=blacklist,
    )
@pytest.fixture
def config_file(tmp_path):
    """Provide a temporary config file."""
    # YAML written to disk; covers every top-level section of the config.
    config_content = """
targets:
  subreddits:
    - name: "pics"
      limit: 25
      sort: "hot"
  users:
    - name: "testuser"
      limit: 10
download:
  output_dir: "./downloads"
  media_types:
    - "image"
    - "video"
  min_score: 10
  skip_nsfw: true
  max_file_size_mb: 100
rate_limit:
  requests_per_minute: 10
  download_delay_seconds: 2
logging:
  level: "INFO"
  file: null
blacklist:
  authors:
    - "spammer"
  subreddits: []
  title_keywords: []
  domains: []
"""
    config_path = tmp_path / "config.yaml"
    config_path.write_text(config_content)
    # The path is returned as a string.
    return str(config_path)

76
tests/test_config.py Normal file
View file

@ -0,0 +1,76 @@
"""Tests for configuration loading and validation."""
import pytest
from src.config import load_config, setup_logging
class TestLoadConfig:
    """Tests for load_config: parsing, validation errors, normalisation and defaults."""

    def test_load_valid_config(self, config_file):
        """A complete config file is parsed into the expected values."""
        config = load_config(config_file)
        assert config.targets.subreddits[0].name == "pics"
        assert config.targets.subreddits[0].limit == 25
        assert config.targets.users[0].name == "testuser"
        assert config.download.min_score == 10
        assert config.download.skip_nsfw is True
        assert config.rate_limit.requests_per_minute == 10

    def test_missing_config_raises(self):
        """A path that does not exist raises FileNotFoundError."""
        with pytest.raises(FileNotFoundError):
            load_config("nonexistent.yaml")

    def test_empty_targets_raises(self, tmp_path):
        """A config with no subreddits and no users is rejected."""
        config_path = tmp_path / "empty.yaml"
        config_path.write_text("targets:\n subreddits: []\n users: []\n")
        with pytest.raises(ValueError, match="No targets configured"):
            load_config(str(config_path))

    def test_blacklist_lowercase(self, tmp_path):
        """Blacklist entries are lower-cased when the config is loaded."""
        config_path = tmp_path / "config.yaml"
        config_path.write_text("""
targets:
  subreddits:
    - name: "test"
      limit: 10
download: {}
rate_limit: {}
logging: {}
blacklist:
  authors:
    - "SpAmMeR"
  subreddits:
    - "BadSub"
  title_keywords:
    - "BUY NOW"
  domains:
    - "Evil.COM"
""")
        config = load_config(str(config_path))
        assert config.blacklist.authors == ["spammer"]
        assert config.blacklist.subreddits == ["badsub"]
        assert config.blacklist.title_keywords == ["buy now"]
        assert config.blacklist.domains == ["evil.com"]

    def test_default_values(self, tmp_path):
        """Sections missing from the file fall back to their defaults."""
        config_path = tmp_path / "minimal.yaml"
        config_path.write_text("""
targets:
  subreddits:
    - name: "test"
""")
        config = load_config(str(config_path))
        assert config.download.output_dir == "./downloads"
        assert config.download.min_score == 10
        assert config.rate_limit.requests_per_minute == 10
class TestSetupLogging:
    """Checks on the logger returned by setup_logging."""

    def test_creates_logger(self, sample_config):
        """The logger carries the project's logger name."""
        configured = setup_logging(sample_config.logging)
        assert configured.name == "reddit_collector"

    def test_log_level(self, sample_config):
        """The configured DEBUG level is applied to the logger."""
        import logging

        configured = setup_logging(sample_config.logging)
        assert configured.level == logging.DEBUG

216
tests/test_database.py Normal file
View file

@ -0,0 +1,216 @@
"""Tests for database operations."""
from datetime import datetime
from src.database import PostRecord
def _make_post(post_id="test123", **kwargs):
    """Build a PostRecord from baseline field values, with keyword overrides applied on top."""
    baseline = {
        "id": post_id,
        "subreddit": "pics",
        "author": "testuser",
        "title": "Test Post",
        "url": "https://reddit.com/test",
        "media_url": "https://i.redd.it/test.jpg",
        "media_type": "image",
        "score": 100,
        "created_utc": 1700000000.0,
        "downloaded_at": None,
        "local_path": None,
        "file_hash": None,
        "permalink": "/r/pics/comments/test123/test_post/",
        "source_type": "subreddit",
        "flair": "OC",
    }
    # Later keys win, so anything passed by the caller replaces the baseline value.
    return PostRecord(**{**baseline, **kwargs})
class TestDatabase:
    """Core post storage, lookup, statistics and listing behaviour."""

    @staticmethod
    def _store(db, post_id, path, file_hash, **fields):
        """Insert a post and mark it as downloaded."""
        db.add_post(_make_post(post_id, **fields))
        db.mark_downloaded(post_id, path, file_hash)

    def test_init_creates_tables(self, db):
        """A new database contains the three expected tables."""
        with db._get_connection() as conn:
            rows = conn.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()
            table_names = {row[0] for row in rows}
        assert "posts" in table_names
        assert "favorites" in table_names
        assert "scheduler_history" in table_names

    def test_add_and_get_post(self, db):
        """A stored post can be read back by id."""
        db.add_post(_make_post())
        fetched = db.get_post("test123")
        assert fetched is not None
        assert fetched.id == "test123"
        assert fetched.subreddit == "pics"
        assert fetched.author == "testuser"

    def test_post_exists(self, db):
        """post_exists flips from False to True once the post is added."""
        assert db.post_exists("test123") is False
        db.add_post(_make_post())
        assert db.post_exists("test123") is True

    def test_mark_downloaded(self, db):
        """mark_downloaded records path, hash and a download timestamp."""
        self._store(db, "test123", "/path/to/file.jpg", "abc123hash")
        stored = db.get_post("test123")
        assert stored.local_path == "/path/to/file.jpg"
        assert stored.file_hash == "abc123hash"
        assert stored.downloaded_at is not None

    def test_hash_exists(self, db):
        """hash_exists returns None for unknown hashes and the path for known ones."""
        assert db.hash_exists("abc123hash") is None
        self._store(db, "test123", "/path/to/file.jpg", "abc123hash")
        assert db.hash_exists("abc123hash") == "/path/to/file.jpg"

    def test_get_stats(self, db):
        """Statistics count both stored and downloaded posts."""
        self._store(db, "p1", "/path/p1.jpg", "hash1")
        self._store(db, "p2", "/path/p2.mp4", "hash2", media_type="video")
        summary = db.get_stats()
        assert summary["total_posts"] == 2
        assert summary["downloaded"] == 2

    def test_get_media_files_sorting(self, db):
        """Each sort mode puts the expected post first."""
        self._store(db, "p1", "/p1.jpg", "h1", score=10, created_utc=1000.0)
        self._store(db, "p2", "/p2.jpg", "h2", score=500, created_utc=2000.0)
        assert db.get_media_files(sort="newest")[0]["id"] == "p2"
        assert db.get_media_files(sort="oldest")[0]["id"] == "p1"
        assert db.get_media_files(sort="score_high")[0]["id"] == "p2"

    def test_get_media_files_filtering(self, db):
        """Subreddit and media-type filters each narrow the listing to one post."""
        self._store(db, "p1", "/p1.jpg", "h1", subreddit="pics")
        self._store(db, "p2", "/p2.mp4", "h2", subreddit="videos", media_type="video")
        from_pics = db.get_media_files(subreddit="pics")
        assert len(from_pics) == 1
        assert from_pics[0]["subreddit"] == "pics"
        only_videos = db.get_media_files(media_type="video")
        assert len(only_videos) == 1

    def test_total_media_count(self, db):
        """The total count honours the optional subreddit filter."""
        self._store(db, "p1", "/p1.jpg", "h1")
        self._store(db, "p2", "/p2.jpg", "h2")
        assert db.get_total_media_count() == 2
        assert db.get_total_media_count(subreddit="pics") == 2
        assert db.get_total_media_count(subreddit="nonexistent") == 0
class TestFavorites:
    """Adding, removing, listing and counting favourite posts."""

    def test_add_favorite(self, db):
        """The first add succeeds; adding the same post again reports False."""
        db.add_post(_make_post())
        first_attempt = db.add_favorite("test123")
        second_attempt = db.add_favorite("test123")
        assert first_attempt is True
        assert second_attempt is False  # duplicate

    def test_remove_favorite(self, db):
        """Removal succeeds once and reports False afterwards."""
        db.add_post(_make_post())
        db.add_favorite("test123")
        first_removal = db.remove_favorite("test123")
        second_removal = db.remove_favorite("test123")
        assert first_removal is True
        assert second_removal is False

    def test_is_favorite(self, db):
        """is_favorite reflects whether the post has been favourited."""
        db.add_post(_make_post())
        assert db.is_favorite("test123") is False
        db.add_favorite("test123")
        assert db.is_favorite("test123") is True

    def test_get_favorites(self, db):
        """Only favourited posts are returned."""
        for post_id in ("p1", "p2"):
            db.add_post(_make_post(post_id))
            db.mark_downloaded(post_id, f"/{post_id}.jpg", f"h{post_id[1:]}")
        db.add_favorite("p1")
        favourites = db.get_favorites()
        assert len(favourites) == 1
        assert favourites[0]["id"] == "p1"

    def test_count_favorites(self, db):
        """The count matches the number of favourited posts."""
        for post_id in ("p1", "p2"):
            db.add_post(_make_post(post_id))
        for post_id in ("p1", "p2"):
            db.add_favorite(post_id)
        assert db.count_favorites() == 2

    def test_get_favorite_authors(self, db):
        """Only authors with at least one favourited post are listed."""
        db.add_post(_make_post("p1", author="alice"))
        db.add_post(_make_post("p2", author="bob"))
        db.add_favorite("p1")
        favourite_authors = db.get_favorite_authors()
        assert "alice" in favourite_authors
        assert "bob" not in favourite_authors
class TestSchedulerHistory:
    """Recording and reading scheduler runs."""

    def test_add_and_finish_run(self, db):
        """A finished run appears in the history with its status and counters."""
        started_at = datetime.now()
        run_id = db.add_scheduler_run(started_at)
        assert run_id > 0
        db.finish_scheduler_run(run_id, "success", 10, 5)
        runs = db.get_scheduler_history()
        assert len(runs) == 1
        latest = runs[0]
        assert latest["status"] == "success"
        assert latest["posts_processed"] == 10

    def test_get_last_scheduler_run(self, db):
        """The last run is None until a run has been recorded."""
        assert db.get_last_scheduler_run() is None
        db.add_scheduler_run(datetime.now())
        assert db.get_last_scheduler_run() is not None
class TestPostsByAuthorsAndSubreddits:
    """Lookups of downloaded posts by author list and by subreddit list."""

    @staticmethod
    def _store(db, post_id, **fields):
        """Insert a post and mark it downloaded as /<id>.jpg with hash h<n>."""
        db.add_post(_make_post(post_id, **fields))
        db.mark_downloaded(post_id, f"/{post_id}.jpg", f"h{post_id[1:]}")

    def test_get_posts_by_authors(self, db):
        """Only posts from the requested authors are returned."""
        self._store(db, "p1", author="alice")
        self._store(db, "p2", author="bob")
        matches = db.get_posts_by_authors(["alice"])
        assert len(matches) == 1
        assert matches[0].author == "alice"

    def test_case_insensitive_author_search(self, db):
        """A lower-case query finds an author stored with mixed case."""
        self._store(db, "p1", author="Alice")
        matches = db.get_posts_by_authors(["alice"])
        assert len(matches) == 1

    def test_count_posts_by_authors(self, db):
        """Counts cover matching, non-matching and empty author lists."""
        self._store(db, "p1", author="alice")
        self._store(db, "p2", author="alice")
        assert db.count_posts_by_authors(["alice"]) == 2
        assert db.count_posts_by_authors(["bob"]) == 0
        assert db.count_posts_by_authors([]) == 0

    def test_get_posts_by_subreddits(self, db):
        """Posts are found by subreddit name."""
        self._store(db, "p1", subreddit="pics")
        matches = db.get_posts_by_subreddits(["pics"])
        assert len(matches) == 1
        assert matches[0].subreddit == "pics"

47
tests/test_downloader.py Normal file
View file

@ -0,0 +1,47 @@
"""Tests for downloader module."""
import hashlib
from src.config import DownloadConfig, RateLimitConfig
from src.downloader import Downloader
class TestDownloader:
    """Unit tests for Downloader helpers that need no network access."""

    @staticmethod
    def _new_downloader(tmp_path):
        """Build a Downloader writing to tmp_path with no download delay."""
        return Downloader(
            DownloadConfig(output_dir=str(tmp_path)),
            RateLimitConfig(download_delay_seconds=0),
        )

    def test_compute_hash(self, tmp_path):
        """compute_hash equals the MD5 hex digest of the file contents."""
        downloader = self._new_downloader(tmp_path)
        payload = b"hello world"
        sample = tmp_path / "test.txt"
        sample.write_bytes(payload)
        assert downloader.compute_hash(str(sample)) == hashlib.md5(payload).hexdigest()

    def test_get_extension_from_url(self, tmp_path):
        """The extension is taken from the URL when it has one."""
        downloader = self._new_downloader(tmp_path)
        assert downloader._get_extension("https://example.com/image.jpg", None) == ".jpg"
        assert downloader._get_extension("https://example.com/image.png", None) == ".png"
        assert downloader._get_extension("https://example.com/video.mp4", None) == ".mp4"

    def test_get_extension_from_content_type(self, tmp_path):
        """Without a URL extension, the content type decides."""
        downloader = self._new_downloader(tmp_path)
        assert downloader._get_extension("https://example.com/blah", "image/jpeg") == ".jpg"
        assert downloader._get_extension("https://example.com/blah", "image/png") == ".png"
        assert downloader._get_extension("https://example.com/blah", "video/mp4") == ".mp4"

    def test_sanitize_name(self, tmp_path):
        """Spaces and punctuation become underscores; dashes and underscores are kept."""
        downloader = self._new_downloader(tmp_path)
        assert downloader._sanitize_name("normal_name") == "normal_name"
        assert downloader._sanitize_name("has spaces!@#") == "has_spaces___"
        assert downloader._sanitize_name("with-dash") == "with-dash"

48
tests/test_extractors.py Normal file
View file

@ -0,0 +1,48 @@
"""Tests for URL extractors."""
from src.extractors import extract_media_url
from src.extractors.imgur import extract_imgur_url
class TestExtractMediaUrl:
    """extract_media_url returns its inputs unchanged for these two URLs."""

    def test_direct_image_passthrough(self):
        """An i.redd.it image URL comes back as-is."""
        source = "https://i.redd.it/test123.jpg"
        extracted_url, extracted_type = extract_media_url(source, "image")
        assert extracted_url == source
        assert extracted_type == "image"

    def test_unknown_domain_passthrough(self):
        """A URL on a domain with no dedicated extractor comes back as-is."""
        source = "https://example.com/image.jpg"
        extracted_url, extracted_type = extract_media_url(source, "image")
        assert extracted_url == source
        assert extracted_type == "image"
class TestImgurExtractor:
    """URL rewriting rules of the Imgur extractor."""

    def test_direct_imgur_url(self):
        """A direct i.imgur.com image link is returned unchanged."""
        source = "https://i.imgur.com/abc123.jpg"
        extracted_url, extracted_type = extract_imgur_url(source)
        assert extracted_url == source
        assert extracted_type == "image"

    def test_gifv_to_mp4(self):
        """A .gifv link is rewritten to the .mp4 video."""
        extracted_url, extracted_type = extract_imgur_url("https://i.imgur.com/abc123.gifv")
        assert extracted_url == "https://i.imgur.com/abc123.mp4"
        assert extracted_type == "video"

    def test_imgur_page_url(self):
        """A page link is rewritten to the direct .jpg on i.imgur.com."""
        extracted_url, extracted_type = extract_imgur_url("https://imgur.com/abc123")
        assert extracted_url == "https://i.imgur.com/abc123.jpg"
        assert extracted_type == "image"

    def test_imgur_album_unsupported(self):
        """Album links yield no URL."""
        extracted_url, _ignored = extract_imgur_url("https://imgur.com/a/abc123")
        assert extracted_url is None

    def test_imgur_gallery_unsupported(self):
        """Gallery links yield no URL."""
        extracted_url, _ignored = extract_imgur_url("https://imgur.com/gallery/abc123")
        assert extracted_url is None

87
tests/test_main.py Normal file
View file

@ -0,0 +1,87 @@
"""Tests for main collector logic."""
from src.main import is_domain_blacklisted, should_download_media, should_download_post
from src.reddit_client import Post
def _make_reddit_post(**kwargs):
    """Build a Post from baseline field values, with keyword overrides applied on top."""
    baseline = {
        "id": "test123",
        "subreddit": "pics",
        "author": "testuser",
        "title": "Test Post",
        "url": "https://i.redd.it/test.jpg",
        "score": 100,
        "created_utc": 1700000000.0,
        "over_18": False,
        "is_gallery": False,
        "preview": None,
        "media_metadata": None,
        "permalink": "/r/pics/test123",
        "flair": None,
    }
    # Later keys win, so anything passed by the caller replaces the baseline value.
    return Post(**{**baseline, **kwargs})
class TestShouldDownloadPost:
    """should_download_post returns (allowed, reason) for each filter rule."""

    def test_normal_post_allowed(self, sample_config):
        """An ordinary post passes with an empty reason."""
        verdict, why = should_download_post(_make_reddit_post(), sample_config)
        assert verdict is True
        assert why == ""

    def test_nsfw_skipped(self, sample_config):
        """NSFW posts are rejected when skip_nsfw is enabled."""
        verdict, why = should_download_post(_make_reddit_post(over_18=True), sample_config)
        assert verdict is False
        assert why == "nsfw"

    def test_low_score_skipped(self, sample_config):
        """Posts scoring below min_score are rejected."""
        verdict, why = should_download_post(_make_reddit_post(score=1), sample_config)
        assert verdict is False
        assert why == "score"

    def test_blacklisted_author(self, sample_config):
        """Posts by a blacklisted author are rejected."""
        verdict, why = should_download_post(_make_reddit_post(author="spammer"), sample_config)
        assert verdict is False
        assert why == "blacklist_author"

    def test_blacklisted_subreddit(self, sample_config):
        """Posts from a blacklisted subreddit are rejected."""
        verdict, why = should_download_post(_make_reddit_post(subreddit="spam_sub"), sample_config)
        assert verdict is False
        assert why == "blacklist_subreddit"

    def test_blacklisted_keyword(self, sample_config):
        """A blacklisted keyword in the title is matched despite different casing."""
        flagged = _make_reddit_post(title="Amazing! Buy now for cheap!")
        verdict, why = should_download_post(flagged, sample_config)
        assert verdict is False
        assert why == "blacklist_keyword"
class TestIsDomainBlacklisted:
    """Domain blacklist matching, including the empty-input cases."""

    BLOCKED = ["malware.com"]

    def test_blacklisted_domain(self):
        """A URL on a listed domain is flagged."""
        assert is_domain_blacklisted("https://malware.com/image.jpg", list(self.BLOCKED)) is True

    def test_clean_domain(self):
        """A URL on an unlisted domain is not flagged."""
        assert is_domain_blacklisted("https://i.redd.it/image.jpg", list(self.BLOCKED)) is False

    def test_empty_blacklist(self):
        """Nothing is flagged when the blacklist is empty."""
        assert is_domain_blacklisted("https://example.com", []) is False

    def test_empty_url(self):
        """An empty URL is never flagged."""
        assert is_domain_blacklisted("", list(self.BLOCKED)) is False
class TestShouldDownloadMedia:
    """Media-type filtering against download.media_types."""

    def test_allowed_type(self, sample_config):
        """Types listed in the config are accepted."""
        for media_type in ("image", "video"):
            assert should_download_media(media_type, sample_config) is True

    def test_disallowed_type(self, sample_config):
        """A type removed from the config is rejected."""
        sample_config.download.media_types = ["image"]
        assert should_download_media("video", sample_config) is False

View file

@ -0,0 +1,43 @@
"""Tests for Reddit client."""
import time
from src.reddit_client import Post, RateLimiter
class TestRateLimiter:
    """Timing behaviour of RateLimiter.wait()."""

    def test_first_request_no_wait(self):
        """The first call returns without a noticeable delay."""
        limiter = RateLimiter(requests_per_minute=60)
        began = time.time()
        limiter.wait()
        assert time.time() - began < 0.1

    def test_respects_rate_limit(self):
        """A call right after a request waits for most of the interval."""
        limiter = RateLimiter(requests_per_minute=120)  # 0.5s interval
        limiter.wait()
        limiter.last_request = time.time()  # simulate request just happened
        began = time.time()
        limiter.wait()
        waited = time.time() - began
        assert waited >= 0.4  # Should wait ~0.5s
class TestPost:
    """Construction of the Post record."""

    def test_post_creation(self):
        """permalink and flair default to None when they are not supplied."""
        required_fields = {
            "id": "test",
            "subreddit": "pics",
            "author": "user",
            "title": "Title",
            "url": "https://example.com",
            "score": 100,
            "created_utc": 1700000000.0,
            "over_18": False,
            "is_gallery": False,
            "preview": None,
            "media_metadata": None,
        }
        created = Post(**required_fields)
        assert created.id == "test"
        assert created.permalink is None
        assert created.flair is None

147
tests/test_sidecar.py Normal file
View file

@ -0,0 +1,147 @@
"""Tests for sidecar file generation."""
import json
from src.sidecar import generate_filename, write_immich_sidecar
class TestWriteImmichSidecar:
    """Contents of the JSON sidecar written next to a media file."""

    @staticmethod
    def _media(tmp_path, payload=b"fake"):
        """Create test.jpg with the given bytes and return its path."""
        target = tmp_path / "test.jpg"
        target.write_bytes(payload)
        return target

    @staticmethod
    def _load(path):
        """Read a sidecar file back as a dict."""
        with open(path) as handle:
            return json.load(handle)

    def test_creates_sidecar_file(self, tmp_path):
        """All metadata fields are written when every argument is supplied."""
        media_file = self._media(tmp_path, b"fake image data")
        sidecar_path = write_immich_sidecar(
            filepath=str(media_file),
            subreddit="pics",
            author="testuser",
            title="A great photo",
            score=150,
            created_utc=1700000000.0,
            media_type="image",
            permalink="/r/pics/comments/abc/a_great_photo/",
            flair="OC",
            source_type="subreddit",
        )
        assert sidecar_path.endswith(".json")
        data = self._load(sidecar_path)
        assert "dateTimeOriginal" in data
        assert data["description"] == "A great photo"
        assert data["albums"] == ["r/pics"]
        for expected_tag in ("reddit", "pics", "OC"):
            assert expected_tag in data["tags"]
        assert data["rating"] == 3  # score 150 -> rating 3
        assert data["people"] == ["testuser"]
        assert "reddit.com" in data["externalUrl"]

    def test_rating_buckets(self, tmp_path):
        """Each score falls into the expected 1-5 rating bucket."""
        media_file = self._media(tmp_path)
        expectations = [
            (5, 1),
            (10, 2),
            (50, 3),
            (200, 4),
            (1000, 5),
            (5000, 5),
        ]
        for score, expected_rating in expectations:
            written = write_immich_sidecar(
                filepath=str(media_file),
                subreddit="test",
                author="user",
                title="test",
                score=score,
                created_utc=1700000000.0,
                media_type="image",
            )
            data = self._load(written)
            assert data["rating"] == expected_rating, f"Score {score} should give rating {expected_rating}"

    def test_deleted_author_no_people(self, tmp_path):
        """A "[deleted]" author produces no people entry."""
        media_file = self._media(tmp_path)
        write_immich_sidecar(
            filepath=str(media_file),
            subreddit="test",
            author="[deleted]",
            title="test",
            score=10,
            created_utc=1700000000.0,
            media_type="image",
        )
        # The sidecar lives next to the media file, with ".json" appended to its name.
        data = self._load(str(media_file) + ".json")
        assert "people" not in data

    def test_no_permalink_no_external_url(self, tmp_path):
        """Without a permalink no externalUrl is written."""
        media_file = self._media(tmp_path)
        write_immich_sidecar(
            filepath=str(media_file),
            subreddit="test",
            author="user",
            title="test",
            score=10,
            created_utc=1700000000.0,
            media_type="image",
        )
        data = self._load(str(media_file) + ".json")
        assert "externalUrl" not in data
class TestGenerateFilename:
    """Tests for ``generate_filename``."""

    @staticmethod
    def _build(**overrides):
        """Call ``generate_filename`` with the shared arguments plus *overrides*."""
        params = {
            "subreddit": "pics",
            "author": "testuser",
            "created_utc": 1700000000.0,
            "post_id": "abc123",
            "ext": ".jpg",
        }
        params.update(overrides)
        return generate_filename(**params)

    def test_basic_filename(self):
        """Subreddit, author and post id appear in the name; it ends with the extension."""
        result = self._build()
        for fragment in ("pics", "testuser", "abc123"):
            assert fragment in result
        assert result.endswith(".jpg")

    def test_gallery_index(self):
        """A gallery index shows up as ``_<index>`` directly before the extension."""
        result = self._build(gallery_index=3)
        assert "_3.jpg" in result

    def test_deleted_author(self):
        """A ``[deleted]`` author is rendered as ``unknown``."""
        result = self._build(author="[deleted]")
        assert "unknown" in result

    def test_sanitized_names(self):
        """Slash, ``!`` and ``@`` from the inputs do not survive into the name."""
        result = self._build(subreddit="pics/test", author="user name!@#")
        # Should not contain special characters
        for forbidden in ("/", "!", "@"):
            assert forbidden not in result

84
tests/test_web_api.py Normal file
View file

@ -0,0 +1,84 @@
"""Tests for web API endpoints."""
from fastapi.testclient import TestClient
from src.web.app import app
# Module-level test client, created once at import and used by every test class in this file.
client = TestClient(app)
class TestHealthEndpoints:
    """Smoke tests: each GET endpoint answers 200, some with expected keys."""

    @staticmethod
    def _fetch_ok(path):
        """GET *path* with the shared client, require a 200, return the response."""
        reply = client.get(path)
        assert reply.status_code == 200
        return reply

    def test_root_returns_html(self):
        """The index route serves an HTML content type."""
        reply = self._fetch_ok("/")
        assert "text/html" in reply.headers["content-type"]

    def test_get_config(self):
        """``/api/config`` responds successfully."""
        self._fetch_ok("/api/config")

    def test_get_stats(self):
        """``/api/stats`` returns JSON containing ``total_posts``."""
        body = self._fetch_ok("/api/stats").json()
        assert "total_posts" in body

    def test_get_collector_status(self):
        """``/api/collector/status`` returns JSON containing ``running``."""
        body = self._fetch_ok("/api/collector/status").json()
        assert "running" in body

    def test_get_media_files(self):
        """``/api/media`` returns JSON containing ``files`` and ``total``."""
        body = self._fetch_ok("/api/media?limit=10").json()
        for key in ("files", "total"):
            assert key in body

    def test_get_subreddits(self):
        """``/api/subreddits`` responds successfully."""
        self._fetch_ok("/api/subreddits")

    def test_get_users(self):
        """``/api/users`` responds successfully."""
        self._fetch_ok("/api/users")

    def test_get_blacklist(self):
        """``/api/blacklist`` responds successfully."""
        self._fetch_ok("/api/blacklist")

    def test_get_settings(self):
        """``/api/settings`` responds successfully."""
        self._fetch_ok("/api/settings")
class TestMediaEndpoints:
    """Media routes answer 404 when the requested item does not exist."""

    def test_file_not_found(self):
        """Fetching an unknown media file yields 404."""
        assert client.get("/api/media/file/nonexistent.jpg").status_code == 404

    def test_thumb_not_found(self):
        """Fetching a thumbnail for an unknown file yields 404."""
        assert client.get("/api/media/thumb/nonexistent.mp4").status_code == 404

    def test_media_info_not_found(self):
        """Requesting info for an unknown media id yields 404."""
        assert client.get("/api/media/nonexistent/info").status_code == 404

    def test_delete_media_not_found(self):
        """Deleting an unknown media id yields 404."""
        assert client.delete("/api/media/nonexistent").status_code == 404
class TestValidation:
    """Invalid request parameters are rejected with 400."""

    def test_cleanup_invalid_media_type(self):
        """An unrecognised ``media_type`` on the cleanup preview yields 400."""
        outcome = client.get("/api/media/cleanup-preview?media_type=invalid")
        assert outcome.status_code == 400

    def test_collect_invalid_target_type(self):
        """An unrecognised ``target_type`` on individual collection yields 400."""
        payload = {"target_type": "invalid", "target_name": "test"}
        outcome = client.post("/api/collect/individual", json=payload)
        assert outcome.status_code == 400