chore: checkpoint baseline (routers, tests, pyproject)
This commit is contained in:
parent
a58498a315
commit
986f1dfef4
37 changed files with 2527 additions and 1664 deletions
15
.dockerignore
Normal file
15
.dockerignore
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
__pycache__
|
||||
*.pyc
|
||||
.pytest_cache
|
||||
.ruff_cache
|
||||
.mypy_cache
|
||||
.git
|
||||
.gitignore
|
||||
*.db
|
||||
downloads/
|
||||
data/
|
||||
*.log
|
||||
.env
|
||||
venv/
|
||||
.venv/
|
||||
tests/
|
||||
36
.github/workflows/ci.yml
vendored
Normal file
36
.github/workflows/ci.yml
vendored
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
- run: pip install ruff
|
||||
- run: ruff check src/
|
||||
- run: ruff format --check src/
|
||||
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.12"
|
||||
- run: pip install -e ".[dev]"
|
||||
- run: pytest tests/ --cov=src --cov-report=term-missing -v
|
||||
|
||||
docker:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [lint, test]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- run: docker build -t reddit-media-collector .
|
||||
6
.gitignore
vendored
6
.gitignore
vendored
|
|
@ -18,3 +18,9 @@ scheduler_config.yaml
|
|||
.thumbs/
|
||||
*.sqlite
|
||||
.DS_Store
|
||||
.coverage
|
||||
.coverage.*
|
||||
htmlcov/
|
||||
.pytest_cache/
|
||||
.ruff_cache/
|
||||
.mypy_cache/
|
||||
|
|
|
|||
7
.pre-commit-config.yaml
Normal file
7
.pre-commit-config.yaml
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.4.8
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
- id: ruff-format
|
||||
35
Dockerfile
Normal file
35
Dockerfile
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
FROM python:3.12-slim AS base
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
ffmpeg \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
FROM base AS builder
|
||||
|
||||
COPY pyproject.toml ./
|
||||
COPY src/ ./src/
|
||||
|
||||
RUN pip install --no-cache-dir .
|
||||
|
||||
FROM base AS runtime
|
||||
|
||||
COPY --from=builder /usr/local/lib/python3.12/site-packages /usr/local/lib/python3.12/site-packages
|
||||
COPY --from=builder /usr/local/bin /usr/local/bin
|
||||
|
||||
WORKDIR /app
|
||||
COPY src/ ./src/
|
||||
|
||||
RUN mkdir -p /app/downloads /app/data
|
||||
|
||||
ENV RMC_DOWNLOAD_DIR=/app/downloads
|
||||
ENV RMC_DB_PATH=/app/data/media.db
|
||||
ENV RMC_CONFIG_PATH=/app/config.yaml
|
||||
ENV RMC_TIMEZONE=UTC
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
VOLUME ["/app/downloads", "/app/data", "/app/config.yaml"]
|
||||
|
||||
CMD ["uvicorn", "src.web.app:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
15
docker-compose.yml
Normal file
15
docker-compose.yml
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
services:
|
||||
collector:
|
||||
build: .
|
||||
container_name: reddit-media-collector
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "8000:8000"
|
||||
volumes:
|
||||
- ./downloads:/app/downloads
|
||||
- ./data:/app/data
|
||||
- ./config.yaml:/app/config.yaml:ro
|
||||
environment:
|
||||
- RMC_TIMEZONE=UTC
|
||||
- RMC_DOWNLOAD_DIR=/app/downloads
|
||||
- RMC_DB_PATH=/app/data/media.db
|
||||
86
pyproject.toml
Normal file
86
pyproject.toml
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src"]
|
||||
|
||||
[project]
|
||||
name = "reddit-media-collector"
|
||||
version = "1.0.0"
|
||||
description = "Self-hosted media collector for Reddit with Immich integration"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.11"
|
||||
license = "MIT"
|
||||
dependencies = [
|
||||
"requests>=2.31.0",
|
||||
"pyyaml>=6.0",
|
||||
"yt-dlp>=2024.0.0",
|
||||
"fastapi>=0.109.0",
|
||||
"uvicorn>=0.27.0",
|
||||
"jinja2>=3.1.0",
|
||||
"apscheduler>=3.10.0",
|
||||
"sqlalchemy>=2.0.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=8.0.0",
|
||||
"pytest-cov>=5.0.0",
|
||||
"pytest-asyncio>=0.23.0",
|
||||
"httpx>=0.27.0",
|
||||
"ruff>=0.4.0",
|
||||
"mypy>=1.10.0",
|
||||
"types-requests>=2.31.0",
|
||||
"types-PyYAML>=6.0",
|
||||
"pre-commit>=3.7.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
reddit-collector = "src.main:main"
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py311"
|
||||
line-length = 120
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"E", # pycodestyle errors
|
||||
"W", # pycodestyle warnings
|
||||
"F", # pyflakes
|
||||
"I", # isort
|
||||
"N", # pep8-naming
|
||||
"UP", # pyupgrade
|
||||
"B", # flake8-bugbear
|
||||
"SIM", # flake8-simplify
|
||||
"TCH", # flake8-type-checking
|
||||
"RUF", # ruff-specific rules
|
||||
]
|
||||
ignore = [
|
||||
"E501", # line too long (handled by formatter)
|
||||
"B008", # do not perform function calls in argument defaults (FastAPI uses this)
|
||||
]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
known-first-party = ["src"]
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.11"
|
||||
warn_return_any = true
|
||||
warn_unused_configs = true
|
||||
disallow_untyped_defs = false
|
||||
check_untyped_defs = true
|
||||
ignore_missing_imports = true
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
pythonpath = ["."]
|
||||
addopts = "-v --tb=short"
|
||||
|
||||
[tool.coverage.run]
|
||||
source = ["src"]
|
||||
omit = ["src/web/templates/*"]
|
||||
|
||||
[tool.coverage.report]
|
||||
show_missing = true
|
||||
skip_empty = true
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
requests>=2.31.0
|
||||
pyyaml>=6.0
|
||||
Pillow>=10.0.0
|
||||
yt-dlp>=2024.0.0
|
||||
fastapi>=0.109.0
|
||||
uvicorn>=0.27.0
|
||||
jinja2>=3.1.0
|
||||
|
|
@ -1,9 +1,9 @@
|
|||
"""Configuration loader and validator."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import yaml
|
||||
|
||||
|
|
@ -49,7 +49,7 @@ class RateLimitConfig:
|
|||
@dataclass
|
||||
class LoggingConfig:
|
||||
level: str = "INFO"
|
||||
file: Optional[str] = "collector.log"
|
||||
file: str | None = "collector.log"
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -79,7 +79,7 @@ def load_config(config_path: str = "config.yaml") -> Config:
|
|||
"Please copy config.yaml.example to config.yaml and customize."
|
||||
)
|
||||
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
with open(path, encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
|
||||
targets_data = data.get("targets", {})
|
||||
|
|
@ -92,10 +92,7 @@ def load_config(config_path: str = "config.yaml") -> Config:
|
|||
)
|
||||
for s in targets_data.get("subreddits", [])
|
||||
]
|
||||
users = [
|
||||
UserTarget(name=u["name"], limit=u.get("limit", 30))
|
||||
for u in targets_data.get("users", [])
|
||||
]
|
||||
users = [UserTarget(name=u["name"], limit=u.get("limit", 30)) for u in targets_data.get("users", [])]
|
||||
targets = TargetsConfig(subreddits=subreddits, users=users)
|
||||
|
||||
if not targets.subreddits and not targets.users:
|
||||
|
|
@ -103,7 +100,7 @@ def load_config(config_path: str = "config.yaml") -> Config:
|
|||
|
||||
download_data = data.get("download", {})
|
||||
download = DownloadConfig(
|
||||
output_dir=download_data.get("output_dir", "./downloads"),
|
||||
output_dir=os.environ.get("RMC_DOWNLOAD_DIR", download_data.get("output_dir", "./downloads")),
|
||||
media_types=download_data.get("media_types", ["image", "video", "gif"]),
|
||||
min_score=download_data.get("min_score", 10),
|
||||
skip_nsfw=download_data.get("skip_nsfw", True),
|
||||
|
|
@ -147,9 +144,7 @@ def setup_logging(config: LoggingConfig) -> logging.Logger:
|
|||
logger = logging.getLogger("reddit_collector")
|
||||
logger.setLevel(getattr(logging, config.level.upper()))
|
||||
|
||||
formatter = logging.Formatter(
|
||||
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
)
|
||||
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setFormatter(formatter)
|
||||
|
|
|
|||
355
src/database.py
355
src/database.py
|
|
@ -1,11 +1,13 @@
|
|||
"""SQLite database for storing post metadata and tracking downloads."""
|
||||
|
||||
import os
|
||||
import sqlite3
|
||||
from contextlib import contextmanager
|
||||
from contextlib import contextmanager, suppress
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
DEFAULT_DB_PATH = os.environ.get("RMC_DB_PATH", "media.db")
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -15,22 +17,22 @@ class PostRecord:
|
|||
author: str
|
||||
title: str
|
||||
url: str
|
||||
media_url: Optional[str]
|
||||
media_type: Optional[str]
|
||||
media_url: str | None
|
||||
media_type: str | None
|
||||
score: int
|
||||
created_utc: float
|
||||
downloaded_at: Optional[datetime]
|
||||
local_path: Optional[str]
|
||||
file_hash: Optional[str]
|
||||
permalink: Optional[str] = None # Reddit permalink for Immich
|
||||
source_type: Optional[str] = None # 'subreddit' or 'user'
|
||||
flair: Optional[str] = None # Post flair for tagging
|
||||
downloaded_at: datetime | None
|
||||
local_path: str | None
|
||||
file_hash: str | None
|
||||
permalink: str | None = None # Reddit permalink for Immich
|
||||
source_type: str | None = None # 'subreddit' or 'user'
|
||||
flair: str | None = None # Post flair for tagging
|
||||
|
||||
|
||||
class Database:
|
||||
"""SQLite database wrapper for tracking downloaded posts."""
|
||||
|
||||
def __init__(self, db_path: str = "media.db"):
|
||||
def __init__(self, db_path: str = DEFAULT_DB_PATH):
|
||||
self.db_path = Path(db_path)
|
||||
self._init_db()
|
||||
|
||||
|
|
@ -65,27 +67,15 @@ class Database:
|
|||
)
|
||||
""")
|
||||
# Add new columns if they don't exist (migration)
|
||||
try:
|
||||
with suppress(sqlite3.OperationalError):
|
||||
conn.execute("ALTER TABLE posts ADD COLUMN permalink TEXT")
|
||||
except sqlite3.OperationalError:
|
||||
pass
|
||||
try:
|
||||
with suppress(sqlite3.OperationalError):
|
||||
conn.execute("ALTER TABLE posts ADD COLUMN source_type TEXT")
|
||||
except sqlite3.OperationalError:
|
||||
pass
|
||||
try:
|
||||
with suppress(sqlite3.OperationalError):
|
||||
conn.execute("ALTER TABLE posts ADD COLUMN flair TEXT")
|
||||
except sqlite3.OperationalError:
|
||||
pass
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_subreddit ON posts(subreddit)"
|
||||
)
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_file_hash ON posts(file_hash)"
|
||||
)
|
||||
conn.execute(
|
||||
"CREATE INDEX IF NOT EXISTS idx_downloaded ON posts(downloaded_at)"
|
||||
)
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_subreddit ON posts(subreddit)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_file_hash ON posts(file_hash)")
|
||||
conn.execute("CREATE INDEX IF NOT EXISTS idx_downloaded ON posts(downloaded_at)")
|
||||
# Scheduler history table
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS scheduler_history (
|
||||
|
|
@ -113,17 +103,13 @@ class Database:
|
|||
def post_exists(self, post_id: str) -> bool:
|
||||
"""Check if a post has already been processed."""
|
||||
with self._get_connection() as conn:
|
||||
cursor = conn.execute(
|
||||
"SELECT 1 FROM posts WHERE id = ?", (post_id,)
|
||||
)
|
||||
cursor = conn.execute("SELECT 1 FROM posts WHERE id = ?", (post_id,))
|
||||
return cursor.fetchone() is not None
|
||||
|
||||
def hash_exists(self, file_hash: str) -> Optional[str]:
|
||||
def hash_exists(self, file_hash: str) -> str | None:
|
||||
"""Check if a file with this hash already exists. Returns local_path if found."""
|
||||
with self._get_connection() as conn:
|
||||
cursor = conn.execute(
|
||||
"SELECT local_path FROM posts WHERE file_hash = ?", (file_hash,)
|
||||
)
|
||||
cursor = conn.execute("SELECT local_path FROM posts WHERE file_hash = ?", (file_hash,))
|
||||
row = cursor.fetchone()
|
||||
return row["local_path"] if row else None
|
||||
|
||||
|
|
@ -158,9 +144,7 @@ class Database:
|
|||
)
|
||||
conn.commit()
|
||||
|
||||
def mark_downloaded(
|
||||
self, post_id: str, local_path: str, file_hash: str
|
||||
) -> None:
|
||||
def mark_downloaded(self, post_id: str, local_path: str, file_hash: str) -> None:
|
||||
"""Mark a post as downloaded with file info."""
|
||||
with self._get_connection() as conn:
|
||||
conn.execute(
|
||||
|
|
@ -182,12 +166,10 @@ class Database:
|
|||
)
|
||||
conn.commit()
|
||||
|
||||
def get_post(self, post_id: str) -> Optional[PostRecord]:
|
||||
def get_post(self, post_id: str) -> PostRecord | None:
|
||||
"""Get a post record by ID."""
|
||||
with self._get_connection() as conn:
|
||||
cursor = conn.execute(
|
||||
"SELECT * FROM posts WHERE id = ?", (post_id,)
|
||||
)
|
||||
cursor = conn.execute("SELECT * FROM posts WHERE id = ?", (post_id,))
|
||||
row = cursor.fetchone()
|
||||
if not row:
|
||||
return None
|
||||
|
|
@ -204,54 +186,49 @@ class Database:
|
|||
downloaded_at=row["downloaded_at"],
|
||||
local_path=row["local_path"],
|
||||
file_hash=row["file_hash"],
|
||||
permalink=row["permalink"] if "permalink" in row.keys() else None,
|
||||
source_type=row["source_type"] if "source_type" in row.keys() else None,
|
||||
flair=row["flair"] if "flair" in row.keys() else None,
|
||||
permalink=row["permalink"],
|
||||
source_type=row["source_type"],
|
||||
flair=row["flair"],
|
||||
)
|
||||
|
||||
def get_all_downloaded(self) -> list[PostRecord]:
|
||||
"""Get all downloaded posts for migration."""
|
||||
with self._get_connection() as conn:
|
||||
cursor = conn.execute(
|
||||
"SELECT * FROM posts WHERE downloaded_at IS NOT NULL"
|
||||
)
|
||||
cursor = conn.execute("SELECT * FROM posts WHERE downloaded_at IS NOT NULL")
|
||||
posts = []
|
||||
for row in cursor.fetchall():
|
||||
posts.append(PostRecord(
|
||||
id=row["id"],
|
||||
subreddit=row["subreddit"],
|
||||
author=row["author"],
|
||||
title=row["title"],
|
||||
url=row["url"],
|
||||
media_url=row["media_url"],
|
||||
media_type=row["media_type"],
|
||||
score=row["score"],
|
||||
created_utc=row["created_utc"],
|
||||
downloaded_at=row["downloaded_at"],
|
||||
local_path=row["local_path"],
|
||||
file_hash=row["file_hash"],
|
||||
permalink=row["permalink"] if "permalink" in row.keys() else None,
|
||||
source_type=row["source_type"] if "source_type" in row.keys() else None,
|
||||
flair=row["flair"] if "flair" in row.keys() else None,
|
||||
))
|
||||
posts.append(
|
||||
PostRecord(
|
||||
id=row["id"],
|
||||
subreddit=row["subreddit"],
|
||||
author=row["author"],
|
||||
title=row["title"],
|
||||
url=row["url"],
|
||||
media_url=row["media_url"],
|
||||
media_type=row["media_type"],
|
||||
score=row["score"],
|
||||
created_utc=row["created_utc"],
|
||||
downloaded_at=row["downloaded_at"],
|
||||
local_path=row["local_path"],
|
||||
file_hash=row["file_hash"],
|
||||
permalink=row["permalink"],
|
||||
source_type=row["source_type"],
|
||||
flair=row["flair"],
|
||||
)
|
||||
)
|
||||
return posts
|
||||
|
||||
def update_local_path(self, post_id: str, new_path: str) -> None:
|
||||
"""Update local_path for a post (used in migration)."""
|
||||
with self._get_connection() as conn:
|
||||
conn.execute(
|
||||
"UPDATE posts SET local_path = ? WHERE id = ?",
|
||||
(new_path, post_id)
|
||||
)
|
||||
conn.execute("UPDATE posts SET local_path = ? WHERE id = ?", (new_path, post_id))
|
||||
conn.commit()
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Get collection statistics."""
|
||||
with self._get_connection() as conn:
|
||||
total = conn.execute("SELECT COUNT(*) FROM posts").fetchone()[0]
|
||||
downloaded = conn.execute(
|
||||
"SELECT COUNT(*) FROM posts WHERE downloaded_at IS NOT NULL"
|
||||
).fetchone()[0]
|
||||
downloaded = conn.execute("SELECT COUNT(*) FROM posts WHERE downloaded_at IS NOT NULL").fetchone()[0]
|
||||
|
||||
# Group by source: subreddits as "r/name", users as "u/name"
|
||||
by_source = {}
|
||||
|
|
@ -312,30 +289,30 @@ class Database:
|
|||
month_start = today_start - timedelta(days=30)
|
||||
|
||||
# Format dates for SQLite comparison (space separator, not 'T')
|
||||
today_str = today_start.strftime('%Y-%m-%d %H:%M:%S')
|
||||
week_str = week_start.strftime('%Y-%m-%d %H:%M:%S')
|
||||
month_str = month_start.strftime('%Y-%m-%d %H:%M:%S')
|
||||
today_str = today_start.strftime("%Y-%m-%d %H:%M:%S")
|
||||
week_str = week_start.strftime("%Y-%m-%d %H:%M:%S")
|
||||
month_str = month_start.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
# Downloads by period
|
||||
downloads_today = conn.execute(
|
||||
"SELECT COUNT(*) FROM posts WHERE downloaded_at >= ?",
|
||||
(today_str,)
|
||||
"SELECT COUNT(*) FROM posts WHERE downloaded_at >= ?", (today_str,)
|
||||
).fetchone()[0]
|
||||
|
||||
downloads_week = conn.execute(
|
||||
"SELECT COUNT(*) FROM posts WHERE downloaded_at >= ?",
|
||||
(week_str,)
|
||||
"SELECT COUNT(*) FROM posts WHERE downloaded_at >= ?", (week_str,)
|
||||
).fetchone()[0]
|
||||
|
||||
downloads_month = conn.execute(
|
||||
"SELECT COUNT(*) FROM posts WHERE downloaded_at >= ?",
|
||||
(month_str,)
|
||||
"SELECT COUNT(*) FROM posts WHERE downloaded_at >= ?", (month_str,)
|
||||
).fetchone()[0]
|
||||
|
||||
# Average score
|
||||
avg_score = conn.execute(
|
||||
"SELECT AVG(score) FROM posts WHERE downloaded_at IS NOT NULL AND score IS NOT NULL"
|
||||
).fetchone()[0] or 0
|
||||
avg_score = (
|
||||
conn.execute(
|
||||
"SELECT AVG(score) FROM posts WHERE downloaded_at IS NOT NULL AND score IS NOT NULL"
|
||||
).fetchone()[0]
|
||||
or 0
|
||||
)
|
||||
|
||||
# Unique authors
|
||||
unique_authors = conn.execute(
|
||||
|
|
@ -343,9 +320,7 @@ class Database:
|
|||
).fetchone()[0]
|
||||
|
||||
# Favorites count
|
||||
favorites_count = conn.execute(
|
||||
"SELECT COUNT(*) FROM favorites"
|
||||
).fetchone()[0]
|
||||
favorites_count = conn.execute("SELECT COUNT(*) FROM favorites").fetchone()[0]
|
||||
|
||||
# Last download
|
||||
last_download = conn.execute(
|
||||
|
|
@ -358,18 +333,16 @@ class Database:
|
|||
"SELECT downloaded_at FROM posts WHERE downloaded_at IS NOT NULL ORDER BY downloaded_at ASC LIMIT 1"
|
||||
).fetchone()
|
||||
|
||||
total_downloaded = conn.execute(
|
||||
"SELECT COUNT(*) FROM posts WHERE downloaded_at IS NOT NULL"
|
||||
).fetchone()[0]
|
||||
total_downloaded = conn.execute("SELECT COUNT(*) FROM posts WHERE downloaded_at IS NOT NULL").fetchone()[0]
|
||||
|
||||
# Calculate average per day
|
||||
avg_per_day = 0
|
||||
if first_download and first_download[0]:
|
||||
try:
|
||||
first_date = datetime.fromisoformat(first_download[0].replace('Z', '+00:00'))
|
||||
first_date = datetime.fromisoformat(first_download[0].replace("Z", "+00:00"))
|
||||
days_active = max((now - first_date.replace(tzinfo=None)).days, 1)
|
||||
avg_per_day = round(total_downloaded / days_active, 1)
|
||||
except:
|
||||
except (ValueError, TypeError):
|
||||
avg_per_day = 0
|
||||
|
||||
# Top 10 authors
|
||||
|
|
@ -389,16 +362,12 @@ class Database:
|
|||
for i in range(13, -1, -1):
|
||||
day = today_start - timedelta(days=i)
|
||||
day_end = day + timedelta(days=1)
|
||||
day_str = day.strftime('%Y-%m-%d %H:%M:%S')
|
||||
day_end_str = day_end.strftime('%Y-%m-%d %H:%M:%S')
|
||||
day_str = day.strftime("%Y-%m-%d %H:%M:%S")
|
||||
day_end_str = day_end.strftime("%Y-%m-%d %H:%M:%S")
|
||||
count = conn.execute(
|
||||
"SELECT COUNT(*) FROM posts WHERE downloaded_at >= ? AND downloaded_at < ?",
|
||||
(day_str, day_end_str)
|
||||
"SELECT COUNT(*) FROM posts WHERE downloaded_at >= ? AND downloaded_at < ?", (day_str, day_end_str)
|
||||
).fetchone()[0]
|
||||
trend_data.append({
|
||||
"date": day.strftime("%m/%d"),
|
||||
"count": count
|
||||
})
|
||||
trend_data.append({"date": day.strftime("%m/%d"), "count": count})
|
||||
|
||||
return {
|
||||
"downloads_today": downloads_today,
|
||||
|
|
@ -410,7 +379,7 @@ class Database:
|
|||
"last_download": last_download,
|
||||
"avg_per_day": avg_per_day,
|
||||
"top_authors": [{"author": a, "count": c} for a, c in top_authors],
|
||||
"trend": trend_data
|
||||
"trend": trend_data,
|
||||
}
|
||||
|
||||
def get_recent_downloads(self, limit: int = 10) -> list[dict]:
|
||||
|
|
@ -425,13 +394,19 @@ class Database:
|
|||
ORDER BY downloaded_at DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(limit,)
|
||||
(limit,),
|
||||
)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
def get_media_files(self, limit: int = 50, offset: int = 0,
|
||||
subreddit: str = None, media_type: str = None,
|
||||
sort: str = "newest", author: str = None) -> list[dict]:
|
||||
def get_media_files(
|
||||
self,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
subreddit: str | None = None,
|
||||
media_type: str | None = None,
|
||||
sort: str = "newest",
|
||||
author: str | None = None,
|
||||
) -> list[dict]:
|
||||
"""Get media files with optional filtering and sorting.
|
||||
|
||||
Args:
|
||||
|
|
@ -474,7 +449,9 @@ class Database:
|
|||
cursor = conn.execute(query, params)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
def get_total_media_count(self, subreddit: str = None, media_type: str = None, author: str = None) -> int:
|
||||
def get_total_media_count(
|
||||
self, subreddit: str | None = None, media_type: str | None = None, author: str | None = None
|
||||
) -> int:
|
||||
"""Get total count of media files with optional filtering."""
|
||||
with self._get_connection() as conn:
|
||||
query = """
|
||||
|
|
@ -534,7 +511,7 @@ class Database:
|
|||
|
||||
with self._get_connection() as conn:
|
||||
# Create placeholders for IN clause
|
||||
placeholders = ','.join('?' * len(authors))
|
||||
placeholders = ",".join("?" * len(authors))
|
||||
# Convert authors to lowercase for case-insensitive matching
|
||||
authors_lower = [a.lower() for a in authors]
|
||||
|
||||
|
|
@ -545,28 +522,30 @@ class Database:
|
|||
AND local_path IS NOT NULL
|
||||
AND downloaded_at IS NOT NULL
|
||||
""",
|
||||
authors_lower
|
||||
authors_lower,
|
||||
)
|
||||
|
||||
posts = []
|
||||
for row in cursor.fetchall():
|
||||
posts.append(PostRecord(
|
||||
id=row["id"],
|
||||
subreddit=row["subreddit"],
|
||||
author=row["author"],
|
||||
title=row["title"],
|
||||
url=row["url"],
|
||||
media_url=row["media_url"],
|
||||
media_type=row["media_type"],
|
||||
score=row["score"],
|
||||
created_utc=row["created_utc"],
|
||||
downloaded_at=row["downloaded_at"],
|
||||
local_path=row["local_path"],
|
||||
file_hash=row["file_hash"],
|
||||
permalink=row["permalink"] if "permalink" in row.keys() else None,
|
||||
source_type=row["source_type"] if "source_type" in row.keys() else None,
|
||||
flair=row["flair"] if "flair" in row.keys() else None,
|
||||
))
|
||||
posts.append(
|
||||
PostRecord(
|
||||
id=row["id"],
|
||||
subreddit=row["subreddit"],
|
||||
author=row["author"],
|
||||
title=row["title"],
|
||||
url=row["url"],
|
||||
media_url=row["media_url"],
|
||||
media_type=row["media_type"],
|
||||
score=row["score"],
|
||||
created_utc=row["created_utc"],
|
||||
downloaded_at=row["downloaded_at"],
|
||||
local_path=row["local_path"],
|
||||
file_hash=row["file_hash"],
|
||||
permalink=row["permalink"],
|
||||
source_type=row["source_type"],
|
||||
flair=row["flair"],
|
||||
)
|
||||
)
|
||||
return posts
|
||||
|
||||
def count_posts_by_authors(self, authors: list[str]) -> int:
|
||||
|
|
@ -575,7 +554,7 @@ class Database:
|
|||
return 0
|
||||
|
||||
with self._get_connection() as conn:
|
||||
placeholders = ','.join('?' * len(authors))
|
||||
placeholders = ",".join("?" * len(authors))
|
||||
authors_lower = [a.lower() for a in authors]
|
||||
|
||||
cursor = conn.execute(
|
||||
|
|
@ -585,7 +564,7 @@ class Database:
|
|||
AND local_path IS NOT NULL
|
||||
AND downloaded_at IS NOT NULL
|
||||
""",
|
||||
authors_lower
|
||||
authors_lower,
|
||||
)
|
||||
return cursor.fetchone()[0]
|
||||
|
||||
|
|
@ -595,7 +574,7 @@ class Database:
|
|||
return []
|
||||
|
||||
with self._get_connection() as conn:
|
||||
placeholders = ','.join('?' * len(subreddits))
|
||||
placeholders = ",".join("?" * len(subreddits))
|
||||
subreddits_lower = [s.lower() for s in subreddits]
|
||||
|
||||
cursor = conn.execute(
|
||||
|
|
@ -605,28 +584,30 @@ class Database:
|
|||
AND local_path IS NOT NULL
|
||||
AND downloaded_at IS NOT NULL
|
||||
""",
|
||||
subreddits_lower
|
||||
subreddits_lower,
|
||||
)
|
||||
|
||||
posts = []
|
||||
for row in cursor.fetchall():
|
||||
posts.append(PostRecord(
|
||||
id=row["id"],
|
||||
subreddit=row["subreddit"],
|
||||
author=row["author"],
|
||||
title=row["title"],
|
||||
url=row["url"],
|
||||
media_url=row["media_url"],
|
||||
media_type=row["media_type"],
|
||||
score=row["score"],
|
||||
created_utc=row["created_utc"],
|
||||
downloaded_at=row["downloaded_at"],
|
||||
local_path=row["local_path"],
|
||||
file_hash=row["file_hash"],
|
||||
permalink=row["permalink"] if "permalink" in row.keys() else None,
|
||||
source_type=row["source_type"] if "source_type" in row.keys() else None,
|
||||
flair=row["flair"] if "flair" in row.keys() else None,
|
||||
))
|
||||
posts.append(
|
||||
PostRecord(
|
||||
id=row["id"],
|
||||
subreddit=row["subreddit"],
|
||||
author=row["author"],
|
||||
title=row["title"],
|
||||
url=row["url"],
|
||||
media_url=row["media_url"],
|
||||
media_type=row["media_type"],
|
||||
score=row["score"],
|
||||
created_utc=row["created_utc"],
|
||||
downloaded_at=row["downloaded_at"],
|
||||
local_path=row["local_path"],
|
||||
file_hash=row["file_hash"],
|
||||
permalink=row["permalink"],
|
||||
source_type=row["source_type"],
|
||||
flair=row["flair"],
|
||||
)
|
||||
)
|
||||
return posts
|
||||
|
||||
def count_posts_by_subreddits(self, subreddits: list[str]) -> int:
|
||||
|
|
@ -635,7 +616,7 @@ class Database:
|
|||
return 0
|
||||
|
||||
with self._get_connection() as conn:
|
||||
placeholders = ','.join('?' * len(subreddits))
|
||||
placeholders = ",".join("?" * len(subreddits))
|
||||
subreddits_lower = [s.lower() for s in subreddits]
|
||||
|
||||
cursor = conn.execute(
|
||||
|
|
@ -645,7 +626,7 @@ class Database:
|
|||
AND local_path IS NOT NULL
|
||||
AND downloaded_at IS NOT NULL
|
||||
""",
|
||||
subreddits_lower
|
||||
subreddits_lower,
|
||||
)
|
||||
return cursor.fetchone()[0]
|
||||
|
||||
|
|
@ -655,10 +636,7 @@ class Database:
|
|||
"""Add a post to favorites. Returns True if added, False if already exists."""
|
||||
with self._get_connection() as conn:
|
||||
try:
|
||||
conn.execute(
|
||||
"INSERT INTO favorites (post_id) VALUES (?)",
|
||||
(post_id,)
|
||||
)
|
||||
conn.execute("INSERT INTO favorites (post_id) VALUES (?)", (post_id,))
|
||||
conn.commit()
|
||||
return True
|
||||
except sqlite3.IntegrityError:
|
||||
|
|
@ -667,20 +645,14 @@ class Database:
|
|||
def remove_favorite(self, post_id: str) -> bool:
|
||||
"""Remove a post from favorites. Returns True if removed."""
|
||||
with self._get_connection() as conn:
|
||||
cursor = conn.execute(
|
||||
"DELETE FROM favorites WHERE post_id = ?",
|
||||
(post_id,)
|
||||
)
|
||||
cursor = conn.execute("DELETE FROM favorites WHERE post_id = ?", (post_id,))
|
||||
conn.commit()
|
||||
return cursor.rowcount > 0
|
||||
|
||||
def is_favorite(self, post_id: str) -> bool:
|
||||
"""Check if a post is favorited."""
|
||||
with self._get_connection() as conn:
|
||||
cursor = conn.execute(
|
||||
"SELECT 1 FROM favorites WHERE post_id = ?",
|
||||
(post_id,)
|
||||
)
|
||||
cursor = conn.execute("SELECT 1 FROM favorites WHERE post_id = ?", (post_id,))
|
||||
return cursor.fetchone() is not None
|
||||
|
||||
def get_favorites(self, limit: int = 50, offset: int = 0) -> list[dict]:
|
||||
|
|
@ -696,7 +668,7 @@ class Database:
|
|||
ORDER BY f.favorited_at DESC
|
||||
LIMIT ? OFFSET ?
|
||||
""",
|
||||
(limit, offset)
|
||||
(limit, offset),
|
||||
)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
|
|
@ -727,16 +699,16 @@ class Database:
|
|||
authors: list[str],
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
subreddit: str = None,
|
||||
media_type: str = None,
|
||||
sort: str = "newest"
|
||||
subreddit: str | None = None,
|
||||
media_type: str | None = None,
|
||||
sort: str = "newest",
|
||||
) -> list[dict]:
|
||||
"""Get media files from specific authors."""
|
||||
if not authors:
|
||||
return []
|
||||
|
||||
with self._get_connection() as conn:
|
||||
placeholders = ','.join('?' * len(authors))
|
||||
placeholders = ",".join("?" * len(authors))
|
||||
authors_lower = [a.lower() for a in authors]
|
||||
|
||||
query = f"""
|
||||
|
|
@ -774,17 +746,14 @@ class Database:
|
|||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
def count_media_by_authors(
|
||||
self,
|
||||
authors: list[str],
|
||||
subreddit: str = None,
|
||||
media_type: str = None
|
||||
self, authors: list[str], subreddit: str | None = None, media_type: str | None = None
|
||||
) -> int:
|
||||
"""Count media files from specific authors."""
|
||||
if not authors:
|
||||
return 0
|
||||
|
||||
with self._get_connection() as conn:
|
||||
placeholders = ','.join('?' * len(authors))
|
||||
placeholders = ",".join("?" * len(authors))
|
||||
authors_lower = [a.lower() for a in authors]
|
||||
|
||||
query = f"""
|
||||
|
|
@ -807,11 +776,7 @@ class Database:
|
|||
return cursor.fetchone()[0]
|
||||
|
||||
def get_authors_with_stats(
|
||||
self,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
favorites_only: bool = False,
|
||||
sort: str = "count"
|
||||
self, limit: int = 50, offset: int = 0, favorites_only: bool = False, sort: str = "count"
|
||||
) -> list[dict]:
|
||||
"""Get list of authors with their media counts and a sample thumbnail.
|
||||
|
||||
|
|
@ -885,22 +850,27 @@ class Database:
|
|||
for row in rows:
|
||||
author_name = row[0]
|
||||
# Check if author has any favorited posts
|
||||
fav_cursor = conn.execute("""
|
||||
fav_cursor = conn.execute(
|
||||
"""
|
||||
SELECT COUNT(*) FROM favorites f
|
||||
JOIN posts p ON f.post_id = p.id
|
||||
WHERE p.author = ?
|
||||
""", (author_name,))
|
||||
""",
|
||||
(author_name,),
|
||||
)
|
||||
fav_count = fav_cursor.fetchone()[0]
|
||||
|
||||
authors.append({
|
||||
"author": author_name,
|
||||
"media_count": row[1],
|
||||
"max_score": row[2],
|
||||
"total_score": row[3],
|
||||
"latest_post": row[4],
|
||||
"thumb_path": row[5],
|
||||
"is_favorite": fav_count > 0
|
||||
})
|
||||
authors.append(
|
||||
{
|
||||
"author": author_name,
|
||||
"media_count": row[1],
|
||||
"max_score": row[2],
|
||||
"total_score": row[3],
|
||||
"latest_post": row[4],
|
||||
"thumb_path": row[5],
|
||||
"is_favorite": fav_count > 0,
|
||||
}
|
||||
)
|
||||
|
||||
return authors
|
||||
|
||||
|
|
@ -939,8 +909,7 @@ class Database:
|
|||
"""Start a new scheduler run. Returns the run ID."""
|
||||
with self._get_connection() as conn:
|
||||
cursor = conn.execute(
|
||||
"INSERT INTO scheduler_history (started_at, status) VALUES (?, 'running')",
|
||||
(started_at,)
|
||||
"INSERT INTO scheduler_history (started_at, status) VALUES (?, 'running')", (started_at,)
|
||||
)
|
||||
conn.commit()
|
||||
return cursor.lastrowid
|
||||
|
|
@ -951,7 +920,7 @@ class Database:
|
|||
status: str,
|
||||
posts_processed: int = 0,
|
||||
posts_downloaded: int = 0,
|
||||
error_message: str = None
|
||||
error_message: str | None = None,
|
||||
) -> None:
|
||||
"""Finish a scheduler run with results."""
|
||||
with self._get_connection() as conn:
|
||||
|
|
@ -962,7 +931,7 @@ class Database:
|
|||
posts_downloaded = ?, error_message = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(datetime.now(), status, posts_processed, posts_downloaded, error_message, run_id)
|
||||
(datetime.now(), status, posts_processed, posts_downloaded, error_message, run_id),
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
|
@ -977,11 +946,11 @@ class Database:
|
|||
ORDER BY started_at DESC
|
||||
LIMIT ?
|
||||
""",
|
||||
(limit,)
|
||||
(limit,),
|
||||
)
|
||||
return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
def get_last_scheduler_run(self) -> Optional[dict]:
|
||||
def get_last_scheduler_run(self) -> dict | None:
|
||||
"""Get the most recent scheduler run."""
|
||||
with self._get_connection() as conn:
|
||||
cursor = conn.execute(
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ import mimetypes
|
|||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
|
|
@ -20,6 +19,7 @@ logger = logging.getLogger("reddit_collector")
|
|||
@dataclass
|
||||
class DownloadMetadata:
|
||||
"""Metadata needed for downloading and sidecar generation."""
|
||||
|
||||
subreddit: str
|
||||
author: str
|
||||
title: str
|
||||
|
|
@ -27,10 +27,11 @@ class DownloadMetadata:
|
|||
created_utc: float
|
||||
post_id: str
|
||||
media_type: str
|
||||
gallery_index: Optional[int] = None
|
||||
permalink: Optional[str] = None
|
||||
flair: Optional[str] = None
|
||||
source_type: Optional[str] = None
|
||||
gallery_index: int | None = None
|
||||
permalink: str | None = None
|
||||
flair: str | None = None
|
||||
source_type: str | None = None
|
||||
|
||||
|
||||
MIME_TO_EXT = {
|
||||
"image/jpeg": ".jpg",
|
||||
|
|
@ -60,17 +61,19 @@ class Downloader:
|
|||
def _create_session(self) -> requests.Session:
|
||||
"""Create a requests session with appropriate headers."""
|
||||
session = requests.Session()
|
||||
session.headers.update({
|
||||
"User-Agent": "Mozilla/5.0 (compatible; RedditMediaCollector/1.0)",
|
||||
"Accept": "image/*,video/*,*/*",
|
||||
})
|
||||
session.headers.update(
|
||||
{
|
||||
"User-Agent": "Mozilla/5.0 (compatible; RedditMediaCollector/1.0)",
|
||||
"Accept": "image/*,video/*,*/*",
|
||||
}
|
||||
)
|
||||
return session
|
||||
|
||||
def download(
|
||||
self,
|
||||
url: str,
|
||||
metadata: DownloadMetadata,
|
||||
) -> tuple[Optional[str], Optional[str]]:
|
||||
) -> tuple[str | None, str | None]:
|
||||
"""
|
||||
Download media from URL.
|
||||
Returns (local_path, file_hash) or (None, None) on failure.
|
||||
|
|
@ -117,10 +120,9 @@ class Downloader:
|
|||
except requests.exceptions.RequestException as e:
|
||||
last_error = e
|
||||
if attempt < max_retries - 1:
|
||||
wait_time = 2 ** attempt
|
||||
wait_time = 2**attempt
|
||||
logger.warning(
|
||||
f"Download failed (attempt {attempt + 1}/{max_retries}), "
|
||||
f"retrying in {wait_time}s: {e}"
|
||||
f"Download failed (attempt {attempt + 1}/{max_retries}), retrying in {wait_time}s: {e}"
|
||||
)
|
||||
time.sleep(wait_time)
|
||||
|
||||
|
|
@ -137,9 +139,7 @@ class Downloader:
|
|||
|
||||
content_length = response.headers.get("Content-Length")
|
||||
if content_length and int(content_length) > self.max_size:
|
||||
raise ValueError(
|
||||
f"File too large: {int(content_length) / 1024 / 1024:.1f}MB"
|
||||
)
|
||||
raise ValueError(f"File too large: {int(content_length) / 1024 / 1024:.1f}MB")
|
||||
|
||||
ext = self._get_extension(url, response.headers.get("Content-Type"))
|
||||
|
||||
|
|
@ -184,9 +184,7 @@ class Downloader:
|
|||
|
||||
return str(local_path), file_hash
|
||||
|
||||
def _get_extension(
|
||||
self, url: str, content_type: Optional[str]
|
||||
) -> str:
|
||||
def _get_extension(self, url: str, content_type: str | None) -> str:
|
||||
"""Determine file extension from URL or Content-Type."""
|
||||
if content_type:
|
||||
content_type = content_type.split(";")[0].strip()
|
||||
|
|
|
|||
|
|
@ -1,12 +1,11 @@
|
|||
"""Media URL extractors for various hosts."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from .gfycat import extract_gfycat_url
|
||||
from .imgur import extract_imgur_url
|
||||
from .reddit import extract_reddit_video_url
|
||||
from .gfycat import extract_gfycat_url
|
||||
|
||||
logger = logging.getLogger("reddit_collector")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,12 +1,11 @@
|
|||
"""Gfycat/Redgifs URL extractor using yt-dlp."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger("reddit_collector")
|
||||
|
||||
|
||||
def extract_gfycat_url(url: str) -> Optional[str]:
|
||||
def extract_gfycat_url(url: str) -> str | None:
|
||||
"""
|
||||
Extract video URL from Gfycat/Redgifs links.
|
||||
Uses yt-dlp for extraction.
|
||||
|
|
@ -26,10 +25,7 @@ def extract_gfycat_url(url: str) -> Optional[str]:
|
|||
return info["url"]
|
||||
|
||||
if info and "formats" in info:
|
||||
mp4_formats = [
|
||||
f for f in info["formats"]
|
||||
if f.get("ext") == "mp4" and f.get("url")
|
||||
]
|
||||
mp4_formats = [f for f in info["formats"] if f.get("ext") == "mp4" and f.get("url")]
|
||||
if mp4_formats:
|
||||
best = max(mp4_formats, key=lambda x: x.get("height", 0))
|
||||
return best["url"]
|
||||
|
|
|
|||
|
|
@ -2,13 +2,12 @@
|
|||
|
||||
import logging
|
||||
import re
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
logger = logging.getLogger("reddit_collector")
|
||||
|
||||
|
||||
def extract_imgur_url(url: str) -> tuple[Optional[str], str]:
|
||||
def extract_imgur_url(url: str) -> tuple[str | None, str]:
|
||||
"""
|
||||
Extract direct image/video URL from Imgur links.
|
||||
Returns (url, media_type) or (None, "image") on failure.
|
||||
|
|
|
|||
|
|
@ -1,12 +1,11 @@
|
|||
"""Reddit video (v.redd.it) URL extractor using yt-dlp."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger("reddit_collector")
|
||||
|
||||
|
||||
def extract_reddit_video_url(url: str) -> Optional[str]:
|
||||
def extract_reddit_video_url(url: str) -> str | None:
|
||||
"""
|
||||
Extract video URL from v.redd.it links.
|
||||
Uses yt-dlp to get the actual video URL.
|
||||
|
|
@ -27,10 +26,7 @@ def extract_reddit_video_url(url: str) -> Optional[str]:
|
|||
return info["url"]
|
||||
|
||||
if info and "formats" in info:
|
||||
mp4_formats = [
|
||||
f for f in info["formats"]
|
||||
if f.get("ext") == "mp4" and f.get("url")
|
||||
]
|
||||
mp4_formats = [f for f in info["formats"] if f.get("ext") == "mp4" and f.get("url")]
|
||||
if mp4_formats:
|
||||
best = max(mp4_formats, key=lambda x: x.get("height", 0))
|
||||
return best["url"]
|
||||
|
|
|
|||
50
src/main.py
50
src/main.py
|
|
@ -57,13 +57,10 @@ def is_domain_blacklisted(url: str, blacklist_domains: list[str]) -> bool:
|
|||
return False
|
||||
|
||||
url_lower = url.lower()
|
||||
for domain in blacklist_domains:
|
||||
if domain in url_lower:
|
||||
return True
|
||||
return False
|
||||
return any(domain in url_lower for domain in blacklist_domains)
|
||||
|
||||
|
||||
def should_download_media(media_type: str, config: Config, author: str = None, db: Database = None) -> bool:
|
||||
def should_download_media(media_type: str, config: Config, author: str | None = None, db: Database = None) -> bool:
|
||||
"""Check if media type is allowed. For videos, optionally check if author is favorited."""
|
||||
if media_type not in config.download.media_types:
|
||||
return False
|
||||
|
|
@ -116,10 +113,7 @@ def process_post(
|
|||
# Process each media item
|
||||
for idx, (media_url, media_type) in enumerate(media_urls):
|
||||
# Generate unique ID for gallery items
|
||||
if len(media_urls) > 1:
|
||||
item_id = f"{post.id}_{idx + 1}"
|
||||
else:
|
||||
item_id = post.id
|
||||
item_id = f"{post.id}_{idx + 1}" if len(media_urls) > 1 else post.id
|
||||
|
||||
# Skip if already in database
|
||||
if db.post_exists(item_id):
|
||||
|
|
@ -178,12 +172,12 @@ def process_post(
|
|||
# Correct media_type based on actual file extension
|
||||
ext = os.path.splitext(local_path)[1].lower()
|
||||
actual_type = final_type
|
||||
if ext in ('.jpg', '.jpeg', '.png', '.webp'):
|
||||
actual_type = 'image'
|
||||
elif ext == '.gif':
|
||||
actual_type = 'gif'
|
||||
elif ext in ('.mp4', '.webm', '.mov'):
|
||||
actual_type = 'video'
|
||||
if ext in (".jpg", ".jpeg", ".png", ".webp"):
|
||||
actual_type = "image"
|
||||
elif ext == ".gif":
|
||||
actual_type = "gif"
|
||||
elif ext in (".mp4", ".webm", ".mov"):
|
||||
actual_type = "video"
|
||||
|
||||
if actual_type != final_type:
|
||||
logger.debug(f"Corrected media_type for {item_id}: {final_type} -> {actual_type}")
|
||||
|
|
@ -211,14 +205,9 @@ def process_post(
|
|||
|
||||
stats.downloaded += 1
|
||||
if len(media_urls) > 1:
|
||||
logger.info(
|
||||
f"Downloaded: {item_id} from r/{post.subreddit} ({actual_type}) "
|
||||
f"[{idx + 1}/{len(media_urls)}]"
|
||||
)
|
||||
logger.info(f"Downloaded: {item_id} from r/{post.subreddit} ({actual_type}) [{idx + 1}/{len(media_urls)}]")
|
||||
else:
|
||||
logger.info(
|
||||
f"Downloaded: {item_id} from r/{post.subreddit} ({actual_type})"
|
||||
)
|
||||
logger.info(f"Downloaded: {item_id} from r/{post.subreddit} ({actual_type})")
|
||||
|
||||
|
||||
def collect(config: Config, logger) -> CollectionStats:
|
||||
|
|
@ -233,10 +222,7 @@ def collect(config: Config, logger) -> CollectionStats:
|
|||
logger.info(f"Processing subreddit: r/{target.name}")
|
||||
try:
|
||||
for post in client.get_subreddit_posts(target):
|
||||
process_post(
|
||||
post, client, db, downloader, config, stats, logger,
|
||||
source_type="subreddit"
|
||||
)
|
||||
process_post(post, client, db, downloader, config, stats, logger, source_type="subreddit")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing r/{target.name}: {e}")
|
||||
stats.errors += 1
|
||||
|
|
@ -245,10 +231,7 @@ def collect(config: Config, logger) -> CollectionStats:
|
|||
logger.info(f"Processing user: u/{target.name}")
|
||||
try:
|
||||
for post in client.get_user_posts(target):
|
||||
process_post(
|
||||
post, client, db, downloader, config, stats, logger,
|
||||
source_type="user"
|
||||
)
|
||||
process_post(post, client, db, downloader, config, stats, logger, source_type="user")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing u/{target.name}: {e}")
|
||||
stats.errors += 1
|
||||
|
|
@ -289,11 +272,10 @@ def print_report(stats: CollectionStats, db: Database, logger) -> None:
|
|||
|
||||
def main():
|
||||
"""Entry point."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Collect images and videos from Reddit"
|
||||
)
|
||||
parser = argparse.ArgumentParser(description="Collect images and videos from Reddit")
|
||||
parser.add_argument(
|
||||
"-c", "--config",
|
||||
"-c",
|
||||
"--config",
|
||||
default="config.yaml",
|
||||
help="Path to configuration file (default: config.yaml)",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2,8 +2,8 @@
|
|||
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterator, Optional
|
||||
|
||||
import requests
|
||||
|
||||
|
|
@ -17,6 +17,7 @@ BASE_URL = "https://www.reddit.com"
|
|||
@dataclass
|
||||
class Post:
|
||||
"""Represents a Reddit post."""
|
||||
|
||||
id: str
|
||||
subreddit: str
|
||||
author: str
|
||||
|
|
@ -26,10 +27,10 @@ class Post:
|
|||
created_utc: float
|
||||
over_18: bool
|
||||
is_gallery: bool
|
||||
preview: Optional[dict]
|
||||
media_metadata: Optional[dict]
|
||||
permalink: Optional[str] = None # Reddit permalink
|
||||
flair: Optional[str] = None # Post flair text
|
||||
preview: dict | None
|
||||
media_metadata: dict | None
|
||||
permalink: str | None = None # Reddit permalink
|
||||
flair: str | None = None # Post flair text
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
|
|
@ -55,14 +56,16 @@ class RedditClient:
|
|||
def __init__(self, rate_config: RateLimitConfig):
|
||||
self.rate_limiter = RateLimiter(rate_config.requests_per_minute)
|
||||
self.session = requests.Session()
|
||||
self.session.headers.update({
|
||||
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||
"Chrome/120.0.0.0 Safari/537.36",
|
||||
})
|
||||
self.session.headers.update(
|
||||
{
|
||||
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
|
||||
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||
"Chrome/120.0.0.0 Safari/537.36",
|
||||
}
|
||||
)
|
||||
logger.info("Reddit client initialized (public JSON API)")
|
||||
|
||||
def _fetch_json(self, url: str, params: dict = None) -> dict:
|
||||
def _fetch_json(self, url: str, params: dict | None = None) -> dict:
|
||||
"""Fetch JSON from Reddit with rate limiting and error handling."""
|
||||
self.rate_limiter.wait()
|
||||
|
||||
|
|
@ -105,14 +108,9 @@ class RedditClient:
|
|||
flair=flair,
|
||||
)
|
||||
|
||||
def get_subreddit_posts(
|
||||
self, target: SubredditTarget
|
||||
) -> Iterator[Post]:
|
||||
def get_subreddit_posts(self, target: SubredditTarget) -> Iterator[Post]:
|
||||
"""Fetch posts from a subreddit."""
|
||||
logger.info(
|
||||
f"Fetching {target.limit} posts from r/{target.name} "
|
||||
f"(sort: {target.sort})"
|
||||
)
|
||||
logger.info(f"Fetching {target.limit} posts from r/{target.name} (sort: {target.sort})")
|
||||
|
||||
url = f"{BASE_URL}/r/{target.name}/{target.sort}.json"
|
||||
params = {"limit": min(target.limit, 100)}
|
||||
|
|
@ -224,7 +222,7 @@ class RedditClient:
|
|||
if not post.media_metadata:
|
||||
return urls
|
||||
|
||||
for item_id, item in post.media_metadata.items():
|
||||
for _item_id, item in post.media_metadata.items():
|
||||
if item.get("status") == "valid" and item.get("e") == "Image":
|
||||
source = item.get("s", {})
|
||||
if "u" in source:
|
||||
|
|
|
|||
|
|
@ -2,9 +2,7 @@
|
|||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from datetime import UTC, datetime
|
||||
|
||||
logger = logging.getLogger("reddit_collector")
|
||||
|
||||
|
|
@ -17,9 +15,9 @@ def write_immich_sidecar(
|
|||
score: int,
|
||||
created_utc: float,
|
||||
media_type: str,
|
||||
permalink: Optional[str] = None,
|
||||
flair: Optional[str] = None,
|
||||
source_type: Optional[str] = None,
|
||||
permalink: str | None = None,
|
||||
flair: str | None = None,
|
||||
source_type: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Write a JSON sidecar file for Immich.
|
||||
|
|
@ -45,7 +43,7 @@ def write_immich_sidecar(
|
|||
sidecar_path = f"{filepath}.json"
|
||||
|
||||
# Convert timestamp to ISO format
|
||||
dt = datetime.fromtimestamp(created_utc, tz=timezone.utc)
|
||||
dt = datetime.fromtimestamp(created_utc, tz=UTC)
|
||||
date_iso = dt.isoformat()
|
||||
|
||||
# Build tags list
|
||||
|
|
@ -99,7 +97,7 @@ def generate_filename(
|
|||
created_utc: float,
|
||||
post_id: str,
|
||||
ext: str,
|
||||
gallery_index: Optional[int] = None,
|
||||
gallery_index: int | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a descriptive filename with metadata embedded.
|
||||
|
|
@ -122,7 +120,7 @@ def generate_filename(
|
|||
Generated filename
|
||||
"""
|
||||
# Convert timestamp to datetime
|
||||
dt = datetime.fromtimestamp(created_utc, tz=timezone.utc)
|
||||
dt = datetime.fromtimestamp(created_utc, tz=UTC)
|
||||
date_str = dt.strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
# Sanitize components
|
||||
|
|
|
|||
1341
src/web/app.py
1341
src/web/app.py
File diff suppressed because it is too large
Load diff
|
|
@ -1,12 +1,12 @@
|
|||
"""Configuration file manager for CRUD operations."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
CONFIG_PATH = Path(__file__).parent.parent.parent / "config.yaml"
|
||||
CONFIG_PATH = Path(os.environ.get("RMC_CONFIG_PATH", str(Path(__file__).parent.parent.parent / "config.yaml")))
|
||||
|
||||
|
||||
def load_config() -> dict[str, Any]:
|
||||
|
|
@ -14,7 +14,7 @@ def load_config() -> dict[str, Any]:
|
|||
if not CONFIG_PATH.exists():
|
||||
return {"targets": {"subreddits": [], "users": []}}
|
||||
|
||||
with open(CONFIG_PATH, "r", encoding="utf-8") as f:
|
||||
with open(CONFIG_PATH, encoding="utf-8") as f:
|
||||
return yaml.safe_load(f) or {}
|
||||
|
||||
|
||||
|
|
@ -44,11 +44,7 @@ def add_subreddit(name: str, limit: int = 100, sort: str = "new") -> bool:
|
|||
if sub["name"].lower() == name.lower():
|
||||
return False
|
||||
|
||||
config["targets"]["subreddits"].append({
|
||||
"name": name,
|
||||
"limit": limit,
|
||||
"sort": sort
|
||||
})
|
||||
config["targets"]["subreddits"].append({"name": name, "limit": limit, "sort": sort})
|
||||
|
||||
save_config(config)
|
||||
return True
|
||||
|
|
@ -60,9 +56,7 @@ def remove_subreddit(name: str) -> bool:
|
|||
subreddits = config.get("targets", {}).get("subreddits", [])
|
||||
|
||||
original_len = len(subreddits)
|
||||
config["targets"]["subreddits"] = [
|
||||
s for s in subreddits if s["name"].lower() != name.lower()
|
||||
]
|
||||
config["targets"]["subreddits"] = [s for s in subreddits if s["name"].lower() != name.lower()]
|
||||
|
||||
if len(config["targets"]["subreddits"]) < original_len:
|
||||
save_config(config)
|
||||
|
|
@ -90,10 +84,7 @@ def add_user(name: str, limit: int = 100) -> bool:
|
|||
if user["name"].lower() == name.lower():
|
||||
return False
|
||||
|
||||
config["targets"]["users"].append({
|
||||
"name": name,
|
||||
"limit": limit
|
||||
})
|
||||
config["targets"]["users"].append({"name": name, "limit": limit})
|
||||
|
||||
save_config(config)
|
||||
return True
|
||||
|
|
@ -105,9 +96,7 @@ def remove_user(name: str) -> bool:
|
|||
users = config.get("targets", {}).get("users", [])
|
||||
|
||||
original_len = len(users)
|
||||
config["targets"]["users"] = [
|
||||
u for u in users if u["name"].lower() != name.lower()
|
||||
]
|
||||
config["targets"]["users"] = [u for u in users if u["name"].lower() != name.lower()]
|
||||
|
||||
if len(config["targets"]["users"]) < original_len:
|
||||
save_config(config)
|
||||
|
|
@ -117,6 +106,7 @@ def remove_user(name: str) -> bool:
|
|||
|
||||
# Blacklist functions
|
||||
|
||||
|
||||
def _ensure_blacklist(config: dict) -> dict:
|
||||
"""Ensure blacklist structure exists in config."""
|
||||
if "blacklist" not in config:
|
||||
|
|
@ -148,8 +138,7 @@ def add_blacklist_author(author: str) -> bool:
|
|||
# Also remove from users collection if present
|
||||
if "targets" in config and "users" in config["targets"]:
|
||||
config["targets"]["users"] = [
|
||||
u for u in config["targets"]["users"]
|
||||
if u.get("name", "").lower() != author.lower()
|
||||
u for u in config["targets"]["users"] if u.get("name", "").lower() != author.lower()
|
||||
]
|
||||
|
||||
save_config(config)
|
||||
|
|
@ -162,9 +151,7 @@ def remove_blacklist_author(author: str) -> bool:
|
|||
config = _ensure_blacklist(config)
|
||||
|
||||
original_len = len(config["blacklist"]["authors"])
|
||||
config["blacklist"]["authors"] = [
|
||||
a for a in config["blacklist"]["authors"] if a.lower() != author.lower()
|
||||
]
|
||||
config["blacklist"]["authors"] = [a for a in config["blacklist"]["authors"] if a.lower() != author.lower()]
|
||||
|
||||
if len(config["blacklist"]["authors"]) < original_len:
|
||||
save_config(config)
|
||||
|
|
@ -186,8 +173,7 @@ def add_blacklist_subreddit(subreddit: str) -> bool:
|
|||
# Also remove from subreddits collection if present
|
||||
if "targets" in config and "subreddits" in config["targets"]:
|
||||
config["targets"]["subreddits"] = [
|
||||
s for s in config["targets"]["subreddits"]
|
||||
if s.get("name", "").lower() != subreddit.lower()
|
||||
s for s in config["targets"]["subreddits"] if s.get("name", "").lower() != subreddit.lower()
|
||||
]
|
||||
|
||||
save_config(config)
|
||||
|
|
@ -200,9 +186,7 @@ def remove_blacklist_subreddit(subreddit: str) -> bool:
|
|||
config = _ensure_blacklist(config)
|
||||
|
||||
original_len = len(config["blacklist"]["subreddits"])
|
||||
config["blacklist"]["subreddits"] = [
|
||||
s for s in config["blacklist"]["subreddits"] if s.lower() != subreddit.lower()
|
||||
]
|
||||
config["blacklist"]["subreddits"] = [s for s in config["blacklist"]["subreddits"] if s.lower() != subreddit.lower()]
|
||||
|
||||
if len(config["blacklist"]["subreddits"]) < original_len:
|
||||
save_config(config)
|
||||
|
|
@ -266,9 +250,7 @@ def remove_blacklist_domain(domain: str) -> bool:
|
|||
domain = domain.lower().replace("https://", "").replace("http://", "").strip("/")
|
||||
|
||||
original_len = len(config["blacklist"]["domains"])
|
||||
config["blacklist"]["domains"] = [
|
||||
d for d in config["blacklist"]["domains"] if d != domain
|
||||
]
|
||||
config["blacklist"]["domains"] = [d for d in config["blacklist"]["domains"] if d != domain]
|
||||
|
||||
if len(config["blacklist"]["domains"]) < original_len:
|
||||
save_config(config)
|
||||
|
|
|
|||
29
src/web/deps.py
Normal file
29
src/web/deps.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
"""Shared dependencies and state for web application."""
|
||||
|
||||
import os
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
# Project root directory
|
||||
PROJECT_DIR = Path(__file__).parent.parent.parent
|
||||
|
||||
# Downloads directory for serving media files
|
||||
DOWNLOADS_DIR = os.environ.get("RMC_DOWNLOAD_DIR", str(PROJECT_DIR / "downloads"))
|
||||
DOWNLOADS_DIR = Path(DOWNLOADS_DIR)
|
||||
|
||||
# Thumbnails directory
|
||||
THUMBS_DIR = DOWNLOADS_DIR / ".thumbs"
|
||||
THUMBS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Timezone from env var, default to UTC
|
||||
TIMEZONE = os.environ.get("RMC_TIMEZONE", "UTC")
|
||||
|
||||
# Scheduler config file path
|
||||
SCHEDULER_CONFIG_PATH = PROJECT_DIR / "scheduler_config.yaml"
|
||||
|
||||
# Scheduler DB path
|
||||
SCHEDULER_DB_PATH = PROJECT_DIR / "scheduler.db"
|
||||
|
||||
# Collector state (protected by lock for thread safety)
|
||||
collector_lock = threading.Lock()
|
||||
collector_status: dict = {"running": False, "last_run": None, "last_result": None}
|
||||
1
src/web/routers/__init__.py
Normal file
1
src/web/routers/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""FastAPI router modules."""
|
||||
246
src/web/routers/config.py
Normal file
246
src/web/routers/config.py
Normal file
|
|
@ -0,0 +1,246 @@
|
|||
"""Configuration and blacklist API routes."""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .. import config_manager
|
||||
|
||||
router = APIRouter(tags=["config"])
|
||||
|
||||
|
||||
class SubredditCreate(BaseModel):
|
||||
name: str
|
||||
limit: int = 100
|
||||
sort: str = "new"
|
||||
|
||||
|
||||
class UserCreate(BaseModel):
|
||||
name: str
|
||||
limit: int = 100
|
||||
|
||||
|
||||
class BlacklistItem(BaseModel):
|
||||
value: str
|
||||
|
||||
|
||||
@router.get("/api/config")
|
||||
async def get_config():
|
||||
"""Get full configuration."""
|
||||
return config_manager.load_config()
|
||||
|
||||
|
||||
@router.get("/api/subreddits")
|
||||
async def list_subreddits():
|
||||
"""List all subreddits."""
|
||||
return config_manager.get_subreddits()
|
||||
|
||||
|
||||
@router.post("/api/subreddits")
|
||||
async def add_subreddit(data: SubredditCreate):
|
||||
"""Add a new subreddit."""
|
||||
if not data.name:
|
||||
raise HTTPException(status_code=400, detail="Name is required")
|
||||
|
||||
success = config_manager.add_subreddit(data.name, data.limit, data.sort)
|
||||
if not success:
|
||||
raise HTTPException(status_code=409, detail="Subreddit already exists")
|
||||
|
||||
return {"message": f"Subreddit '{data.name}' added successfully"}
|
||||
|
||||
|
||||
@router.delete("/api/subreddits/{name}")
|
||||
async def delete_subreddit(name: str):
|
||||
"""Remove a subreddit."""
|
||||
success = config_manager.remove_subreddit(name)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Subreddit not found")
|
||||
|
||||
return {"message": f"Subreddit '{name}' removed successfully"}
|
||||
|
||||
|
||||
@router.get("/api/users")
|
||||
async def list_users():
|
||||
"""List all users."""
|
||||
return config_manager.get_users()
|
||||
|
||||
|
||||
@router.post("/api/users")
|
||||
async def add_user(data: UserCreate):
|
||||
"""Add a new user."""
|
||||
if not data.name:
|
||||
raise HTTPException(status_code=400, detail="Name is required")
|
||||
|
||||
success = config_manager.add_user(data.name, data.limit)
|
||||
if not success:
|
||||
raise HTTPException(status_code=409, detail="User already exists")
|
||||
|
||||
return {"message": f"User '{data.name}' added successfully"}
|
||||
|
||||
|
||||
@router.delete("/api/users/{name}")
|
||||
async def delete_user(name: str):
|
||||
"""Remove a user."""
|
||||
success = config_manager.remove_user(name)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
return {"message": f"User '{name}' removed successfully"}
|
||||
|
||||
|
||||
# Blacklist endpoints
|
||||
|
||||
|
||||
@router.get("/api/blacklist")
|
||||
async def get_blacklist():
|
||||
"""Get full blacklist configuration."""
|
||||
return config_manager.get_blacklist()
|
||||
|
||||
|
||||
@router.post("/api/blacklist/authors")
|
||||
async def add_blacklist_author(data: BlacklistItem):
|
||||
"""Add an author to the blacklist."""
|
||||
if not data.value:
|
||||
raise HTTPException(status_code=400, detail="Value is required")
|
||||
|
||||
success = config_manager.add_blacklist_author(data.value)
|
||||
if not success:
|
||||
raise HTTPException(status_code=409, detail="Author already blacklisted")
|
||||
|
||||
return {"message": f"Author '{data.value}' added to blacklist"}
|
||||
|
||||
|
||||
@router.delete("/api/blacklist/authors/{author}")
|
||||
async def remove_blacklist_author(author: str):
|
||||
"""Remove an author from the blacklist."""
|
||||
success = config_manager.remove_blacklist_author(author)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Author not found in blacklist")
|
||||
|
||||
return {"message": f"Author '{author}' removed from blacklist"}
|
||||
|
||||
|
||||
@router.post("/api/blacklist/subreddits")
|
||||
async def add_blacklist_subreddit(data: BlacklistItem):
|
||||
"""Add a subreddit to the blacklist."""
|
||||
if not data.value:
|
||||
raise HTTPException(status_code=400, detail="Value is required")
|
||||
|
||||
success = config_manager.add_blacklist_subreddit(data.value)
|
||||
if not success:
|
||||
raise HTTPException(status_code=409, detail="Subreddit already blacklisted")
|
||||
|
||||
return {"message": f"Subreddit '{data.value}' added to blacklist"}
|
||||
|
||||
|
||||
@router.delete("/api/blacklist/subreddits/{subreddit}")
|
||||
async def remove_blacklist_subreddit(subreddit: str):
|
||||
"""Remove a subreddit from the blacklist."""
|
||||
success = config_manager.remove_blacklist_subreddit(subreddit)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Subreddit not found in blacklist")
|
||||
|
||||
return {"message": f"Subreddit '{subreddit}' removed from blacklist"}
|
||||
|
||||
|
||||
@router.post("/api/blacklist/keywords")
|
||||
async def add_blacklist_keyword(data: BlacklistItem):
|
||||
"""Add a title keyword to the blacklist."""
|
||||
if not data.value:
|
||||
raise HTTPException(status_code=400, detail="Value is required")
|
||||
|
||||
success = config_manager.add_blacklist_keyword(data.value)
|
||||
if not success:
|
||||
raise HTTPException(status_code=409, detail="Keyword already blacklisted")
|
||||
|
||||
return {"message": f"Keyword '{data.value}' added to blacklist"}
|
||||
|
||||
|
||||
@router.delete("/api/blacklist/keywords/{keyword:path}")
|
||||
async def remove_blacklist_keyword(keyword: str):
|
||||
"""Remove a title keyword from the blacklist."""
|
||||
success = config_manager.remove_blacklist_keyword(keyword)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Keyword not found in blacklist")
|
||||
|
||||
return {"message": f"Keyword '{keyword}' removed from blacklist"}
|
||||
|
||||
|
||||
@router.post("/api/blacklist/domains")
|
||||
async def add_blacklist_domain(data: BlacklistItem):
|
||||
"""Add a domain to the blacklist."""
|
||||
if not data.value:
|
||||
raise HTTPException(status_code=400, detail="Value is required")
|
||||
|
||||
success = config_manager.add_blacklist_domain(data.value)
|
||||
if not success:
|
||||
raise HTTPException(status_code=409, detail="Domain already blacklisted")
|
||||
|
||||
return {"message": f"Domain '{data.value}' added to blacklist"}
|
||||
|
||||
|
||||
@router.delete("/api/blacklist/domains/{domain:path}")
|
||||
async def remove_blacklist_domain(domain: str):
|
||||
"""Remove a domain from the blacklist."""
|
||||
success = config_manager.remove_blacklist_domain(domain)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Domain not found in blacklist")
|
||||
|
||||
return {"message": f"Domain '{domain}' removed from blacklist"}
|
||||
|
||||
|
||||
# Settings endpoints
|
||||
|
||||
|
||||
class DownloadSettings(BaseModel):
|
||||
media_types: list[str] = ["image"]
|
||||
min_score: int = 1
|
||||
skip_nsfw: bool = False
|
||||
max_file_size_mb: int = 200
|
||||
videos_only_from_favorites: bool = False
|
||||
|
||||
|
||||
@router.get("/api/settings")
|
||||
async def get_settings():
|
||||
"""Get current settings."""
|
||||
config = config_manager.load_config()
|
||||
return {
|
||||
"download": config.get("download", {}),
|
||||
"rate_limit": config.get("rate_limit", {}),
|
||||
"blacklist": config.get("blacklist", {}),
|
||||
}
|
||||
|
||||
|
||||
@router.put("/api/settings/download")
|
||||
async def update_download_settings(settings: DownloadSettings):
|
||||
"""Update download settings."""
|
||||
config = config_manager.load_config()
|
||||
if "download" not in config:
|
||||
config["download"] = {}
|
||||
|
||||
config["download"]["media_types"] = settings.media_types
|
||||
config["download"]["min_score"] = settings.min_score
|
||||
config["download"]["skip_nsfw"] = settings.skip_nsfw
|
||||
config["download"]["max_file_size_mb"] = settings.max_file_size_mb
|
||||
config["download"]["videos_only_from_favorites"] = settings.videos_only_from_favorites
|
||||
|
||||
config_manager.save_config(config)
|
||||
return {"message": "Download settings updated", "settings": settings.model_dump()}
|
||||
|
||||
|
||||
class RateLimitSettings(BaseModel):
|
||||
requests_per_minute: int = 20
|
||||
download_delay_seconds: float = 2.0
|
||||
|
||||
|
||||
@router.put("/api/settings/rate-limit")
|
||||
async def update_rate_limit_settings(settings: RateLimitSettings):
|
||||
"""Update rate limit settings."""
|
||||
config = config_manager.load_config()
|
||||
if "rate_limit" not in config:
|
||||
config["rate_limit"] = {}
|
||||
|
||||
config["rate_limit"]["requests_per_minute"] = settings.requests_per_minute
|
||||
config["rate_limit"]["download_delay_seconds"] = settings.download_delay_seconds
|
||||
|
||||
config_manager.save_config(config)
|
||||
return {"message": "Rate limit settings updated", "settings": settings.model_dump()}
|
||||
115
src/web/routers/favorites.py
Normal file
115
src/web/routers/favorites.py
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
"""Favorites and authors API routes."""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
|
||||
from ...database import Database
|
||||
from .. import config_manager
|
||||
|
||||
router = APIRouter(tags=["favorites"])
|
||||
|
||||
|
||||
@router.get("/api/favorites")
|
||||
async def get_favorites(limit: int = Query(default=50, le=200), offset: int = Query(default=0, ge=0)):
|
||||
"""Get favorited posts."""
|
||||
db = Database()
|
||||
favorites = db.get_favorites(limit, offset)
|
||||
total = db.count_favorites()
|
||||
|
||||
for fav in favorites:
|
||||
fav["is_favorite"] = True
|
||||
|
||||
return {"favorites": favorites, "total": total, "limit": limit, "offset": offset}
|
||||
|
||||
|
||||
@router.post("/api/favorites/{post_id}")
|
||||
async def add_favorite(post_id: str, add_user_to_collection: bool = True):
|
||||
"""Add a post to favorites. Optionally adds the author to user collection."""
|
||||
db = Database()
|
||||
|
||||
post = db.get_post(post_id)
|
||||
if not post:
|
||||
raise HTTPException(status_code=404, detail="Post not found")
|
||||
|
||||
added = db.add_favorite(post_id)
|
||||
|
||||
result = {
|
||||
"message": f"Post '{post_id}' added to favorites" if added else "Post already in favorites",
|
||||
"added": added,
|
||||
}
|
||||
|
||||
if add_user_to_collection and post.author and post.author not in ("[deleted]", "AutoModerator"):
|
||||
user_added = config_manager.add_user(post.author, limit=100)
|
||||
if user_added:
|
||||
result["user_added"] = post.author
|
||||
result["message"] += f". User '{post.author}' added to collection targets."
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.delete("/api/favorites/{post_id}")
|
||||
async def remove_favorite(post_id: str):
|
||||
"""Remove a post from favorites."""
|
||||
db = Database()
|
||||
removed = db.remove_favorite(post_id)
|
||||
|
||||
if not removed:
|
||||
raise HTTPException(status_code=404, detail="Post not in favorites")
|
||||
|
||||
return {"message": f"Post '{post_id}' removed from favorites"}
|
||||
|
||||
|
||||
@router.get("/api/favorites/authors")
|
||||
async def get_favorite_authors():
|
||||
"""Get list of unique authors from favorited posts."""
|
||||
db = Database()
|
||||
return db.get_favorite_authors()
|
||||
|
||||
|
||||
@router.post("/api/favorites/sync-users")
|
||||
async def sync_favorite_authors_to_users():
|
||||
"""Add all authors from favorites to user collection targets."""
|
||||
db = Database()
|
||||
authors = db.get_favorite_authors()
|
||||
|
||||
added = []
|
||||
for author in authors:
|
||||
if config_manager.add_user(author, limit=100):
|
||||
added.append(author)
|
||||
|
||||
return {"synced": len(added), "added_users": added, "message": f"Added {len(added)} users to collection targets"}
|
||||
|
||||
|
||||
# Authors endpoints
|
||||
|
||||
|
||||
@router.get("/api/authors")
|
||||
async def get_authors(
|
||||
limit: int = Query(default=50, le=200),
|
||||
offset: int = Query(default=0, ge=0),
|
||||
favorites_only: bool = Query(default=False),
|
||||
sort: str = Query(default="count"),
|
||||
):
|
||||
"""Get list of authors with stats and thumbnails."""
|
||||
db = Database()
|
||||
authors = db.get_authors_with_stats(limit, offset, favorites_only, sort)
|
||||
total = db.count_authors(favorites_only)
|
||||
|
||||
return {"authors": authors, "total": total, "limit": limit, "offset": offset}
|
||||
|
||||
|
||||
@router.get("/api/authors/{author}/media")
|
||||
async def get_author_media(
|
||||
author: str,
|
||||
limit: int = Query(default=50, le=200),
|
||||
offset: int = Query(default=0, ge=0),
|
||||
sort: str = Query(default="newest"),
|
||||
):
|
||||
"""Get all media from a specific author."""
|
||||
db = Database()
|
||||
files = db.get_media_by_authors([author], limit, offset, sort=sort)
|
||||
total = db.count_media_by_authors([author])
|
||||
|
||||
for f in files:
|
||||
f["is_favorite"] = db.is_favorite(f["id"])
|
||||
|
||||
return {"files": files, "total": total, "author": author, "limit": limit, "offset": offset}
|
||||
384
src/web/routers/media.py
Normal file
384
src/web/routers/media.py
Normal file
|
|
@ -0,0 +1,384 @@
|
|||
"""Media browsing, serving, and cleanup API routes."""
|
||||
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, Header, HTTPException, Query
|
||||
from fastapi.responses import FileResponse, StreamingResponse
|
||||
|
||||
from ...database import Database
|
||||
from .. import config_manager
|
||||
from ..deps import DOWNLOADS_DIR, THUMBS_DIR
|
||||
|
||||
router = APIRouter(tags=["media"])
|
||||
|
||||
|
||||
@router.get("/api/media")
|
||||
async def get_media_files(
|
||||
limit: int = Query(default=50, le=200),
|
||||
offset: int = Query(default=0, ge=0),
|
||||
subreddit: str | None = None,
|
||||
media_type: str | None = None,
|
||||
author: str | None = None,
|
||||
sort: str = Query(default="newest"),
|
||||
favorites_only: bool = Query(default=False),
|
||||
favorite_authors: bool = Query(default=False),
|
||||
):
|
||||
"""Get media files with pagination, filtering and sorting."""
|
||||
db = Database()
|
||||
|
||||
fav_authors_list = db.get_favorite_authors()
|
||||
fav_authors_set = {a.lower() for a in fav_authors_list} if fav_authors_list else set()
|
||||
|
||||
if favorites_only:
|
||||
files = db.get_favorites(limit, offset)
|
||||
total = db.count_favorites()
|
||||
elif favorite_authors:
|
||||
if fav_authors_list:
|
||||
files = db.get_media_by_authors(fav_authors_list, limit, offset, subreddit, media_type)
|
||||
total = db.count_media_by_authors(fav_authors_list, subreddit, media_type)
|
||||
else:
|
||||
files = []
|
||||
total = 0
|
||||
for f in files:
|
||||
f["is_favorite"] = db.is_favorite(f["id"])
|
||||
else:
|
||||
files = db.get_media_files(limit, offset, subreddit, media_type, sort, author)
|
||||
total = db.get_total_media_count(subreddit, media_type, author)
|
||||
for f in files:
|
||||
f["is_favorite"] = db.is_favorite(f["id"])
|
||||
|
||||
for f in files:
|
||||
file_author = f.get("author", "")
|
||||
f["is_author_favorite"] = file_author.lower() in fav_authors_set if file_author else False
|
||||
|
||||
return {"files": files, "total": total, "limit": limit, "offset": offset}
|
||||
|
||||
|
||||
@router.get("/api/media/subreddits")
|
||||
async def get_media_subreddits():
|
||||
"""Get list of subreddits with downloaded content."""
|
||||
db = Database()
|
||||
return db.get_all_subreddits()
|
||||
|
||||
|
||||
@router.get("/api/media/authors")
|
||||
async def get_media_authors():
|
||||
"""Get list of authors with downloaded content."""
|
||||
db = Database()
|
||||
return db.get_all_authors()
|
||||
|
||||
|
||||
@router.get("/api/media/file/{filename:path}")
|
||||
async def get_media_file(filename: str, range: str | None = Header(None)):
|
||||
"""Serve a media file with Range request support for video streaming."""
|
||||
file_path = DOWNLOADS_DIR / filename
|
||||
if not file_path.exists():
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
|
||||
file_size = file_path.stat().st_size
|
||||
|
||||
suffix = file_path.suffix.lower()
|
||||
media_types = {
|
||||
".mp4": "video/mp4",
|
||||
".webm": "video/webm",
|
||||
".mov": "video/quicktime",
|
||||
".avi": "video/x-msvideo",
|
||||
".mkv": "video/x-matroska",
|
||||
".jpg": "image/jpeg",
|
||||
".jpeg": "image/jpeg",
|
||||
".png": "image/png",
|
||||
".gif": "image/gif",
|
||||
".webp": "image/webp",
|
||||
}
|
||||
media_type = media_types.get(suffix, "application/octet-stream")
|
||||
|
||||
if not range or suffix not in (".mp4", ".webm", ".mov", ".avi", ".mkv"):
|
||||
return FileResponse(file_path, media_type=media_type, headers={"Accept-Ranges": "bytes"})
|
||||
|
||||
try:
|
||||
range_str = range.replace("bytes=", "")
|
||||
range_parts = range_str.split("-")
|
||||
start = int(range_parts[0]) if range_parts[0] else 0
|
||||
end = int(range_parts[1]) if range_parts[1] else file_size - 1
|
||||
except (ValueError, IndexError):
|
||||
start = 0
|
||||
end = file_size - 1
|
||||
|
||||
start = max(0, min(start, file_size - 1))
|
||||
end = max(start, min(end, file_size - 1))
|
||||
content_length = end - start + 1
|
||||
|
||||
def iterfile():
|
||||
with open(file_path, "rb") as f:
|
||||
f.seek(start)
|
||||
remaining = content_length
|
||||
chunk_size = 64 * 1024
|
||||
while remaining > 0:
|
||||
read_size = min(chunk_size, remaining)
|
||||
data = f.read(read_size)
|
||||
if not data:
|
||||
break
|
||||
remaining -= len(data)
|
||||
yield data
|
||||
|
||||
headers = {
|
||||
"Content-Range": f"bytes {start}-{end}/{file_size}",
|
||||
"Accept-Ranges": "bytes",
|
||||
"Content-Length": str(content_length),
|
||||
}
|
||||
|
||||
return StreamingResponse(iterfile(), status_code=206, media_type=media_type, headers=headers)
|
||||
|
||||
|
||||
def _generate_thumbnail(video_path: Path) -> Path | None:
|
||||
"""Generate a thumbnail for a video file using ffmpeg."""
|
||||
thumb_path = THUMBS_DIR / f"{video_path.name}.jpg"
|
||||
|
||||
if thumb_path.exists():
|
||||
return thumb_path
|
||||
|
||||
try:
|
||||
for seek_time in ("00:00:01", "00:00:00"):
|
||||
result = subprocess.run(
|
||||
[
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-i",
|
||||
str(video_path),
|
||||
"-ss",
|
||||
seek_time,
|
||||
"-vframes",
|
||||
"1",
|
||||
"-vf",
|
||||
"scale=320:320:force_original_aspect_ratio=decrease",
|
||||
"-q:v",
|
||||
"2",
|
||||
str(thumb_path),
|
||||
],
|
||||
capture_output=True,
|
||||
timeout=30,
|
||||
)
|
||||
if result.returncode == 0 and thumb_path.exists():
|
||||
return thumb_path
|
||||
|
||||
return None
|
||||
except (subprocess.TimeoutExpired, FileNotFoundError):
|
||||
return None
|
||||
|
||||
|
||||
@router.get("/api/media/thumb/{filename:path}")
|
||||
async def get_video_thumbnail(filename: str):
|
||||
"""Get or generate a thumbnail for a video file."""
|
||||
video_path = DOWNLOADS_DIR / filename
|
||||
|
||||
if not video_path.exists():
|
||||
raise HTTPException(status_code=404, detail="Video not found")
|
||||
|
||||
video_extensions = {".mp4", ".webm", ".mov", ".avi", ".mkv"}
|
||||
image_extensions = {".jpg", ".jpeg", ".png", ".gif", ".webp"}
|
||||
|
||||
if video_path.suffix.lower() in image_extensions:
|
||||
return FileResponse(video_path, media_type="image/jpeg")
|
||||
|
||||
if video_path.suffix.lower() not in video_extensions:
|
||||
raise HTTPException(status_code=400, detail="Not a video file")
|
||||
|
||||
thumb_path = _generate_thumbnail(video_path)
|
||||
|
||||
if thumb_path and thumb_path.exists():
|
||||
return FileResponse(thumb_path, media_type="image/jpeg")
|
||||
|
||||
# Check magic bytes as fallback
|
||||
try:
|
||||
with open(video_path, "rb") as f:
|
||||
header = f.read(12)
|
||||
if header[:2] == b"\xff\xd8":
|
||||
return FileResponse(video_path, media_type="image/jpeg")
|
||||
if header[:8] == b"\x89PNG\r\n\x1a\n":
|
||||
return FileResponse(video_path, media_type="image/png")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
raise HTTPException(status_code=500, detail="Failed to generate thumbnail")
|
||||
|
||||
|
||||
@router.delete("/api/media/{post_id}")
|
||||
async def delete_media(post_id: str, blacklist_author: bool = False, blacklist_subreddit: bool = False):
|
||||
"""Delete a media file and its database record."""
|
||||
db = Database()
|
||||
post = db.get_post(post_id)
|
||||
|
||||
if not post:
|
||||
raise HTTPException(status_code=404, detail="Post not found")
|
||||
|
||||
blacklisted = []
|
||||
if blacklist_author and post.author and post.author not in ("[deleted]", "AutoModerator"):
|
||||
config_manager.add_blacklist_author(post.author)
|
||||
blacklisted.append(f"author:{post.author}")
|
||||
|
||||
if blacklist_subreddit and post.subreddit:
|
||||
config_manager.add_blacklist_subreddit(post.subreddit)
|
||||
blacklisted.append(f"subreddit:{post.subreddit}")
|
||||
|
||||
if post.local_path:
|
||||
_delete_media_files(Path(post.local_path))
|
||||
|
||||
with db._get_connection() as conn:
|
||||
conn.execute("DELETE FROM posts WHERE id = ?", (post_id,))
|
||||
conn.commit()
|
||||
|
||||
result = {"message": f"Media '{post_id}' deleted successfully"}
|
||||
if blacklisted:
|
||||
result["blacklisted"] = blacklisted
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/api/media/{post_id}/info")
|
||||
async def get_media_info(post_id: str):
|
||||
"""Get media info for delete confirmation dialog."""
|
||||
db = Database()
|
||||
post = db.get_post(post_id)
|
||||
|
||||
if not post:
|
||||
raise HTTPException(status_code=404, detail="Post not found")
|
||||
|
||||
return {
|
||||
"id": post.id,
|
||||
"author": post.author,
|
||||
"subreddit": post.subreddit,
|
||||
"title": post.title,
|
||||
"media_type": post.media_type,
|
||||
}
|
||||
|
||||
|
||||
# Cleanup endpoints
|
||||
|
||||
|
||||
@router.get("/api/media/blacklist-preview")
|
||||
async def preview_blacklist_cleanup():
|
||||
"""Preview how many files would be deleted by cleanup."""
|
||||
blacklist = config_manager.get_blacklist()
|
||||
authors = blacklist.get("authors", [])
|
||||
subreddits = blacklist.get("subreddits", [])
|
||||
|
||||
db = Database()
|
||||
author_count = db.count_posts_by_authors(authors) if authors else 0
|
||||
subreddit_count = db.count_posts_by_subreddits(subreddits) if subreddits else 0
|
||||
|
||||
return {
|
||||
"author_count": author_count,
|
||||
"subreddit_count": subreddit_count,
|
||||
"total_count": author_count + subreddit_count,
|
||||
"authors": authors,
|
||||
"subreddits": subreddits,
|
||||
}
|
||||
|
||||
|
||||
@router.post("/api/media/cleanup-blacklist")
|
||||
async def cleanup_blacklisted_media():
|
||||
"""Delete all media from blacklisted authors and subreddits."""
|
||||
blacklist = config_manager.get_blacklist()
|
||||
authors = blacklist.get("authors", [])
|
||||
subreddits = blacklist.get("subreddits", [])
|
||||
|
||||
if not authors and not subreddits:
|
||||
return {"deleted": 0, "message": "No authors or subreddits in blacklist"}
|
||||
|
||||
db = Database()
|
||||
posts = []
|
||||
if authors:
|
||||
posts.extend(db.get_posts_by_authors(authors))
|
||||
if subreddits:
|
||||
posts.extend(db.get_posts_by_subreddits(subreddits))
|
||||
|
||||
seen_ids = set()
|
||||
unique_posts = []
|
||||
for post in posts:
|
||||
if post.id not in seen_ids:
|
||||
seen_ids.add(post.id)
|
||||
unique_posts.append(post)
|
||||
|
||||
deleted_count = 0
|
||||
errors = []
|
||||
|
||||
for post in unique_posts:
|
||||
try:
|
||||
if post.local_path:
|
||||
_delete_media_files(Path(post.local_path))
|
||||
|
||||
with db._get_connection() as conn:
|
||||
conn.execute("DELETE FROM posts WHERE id = ?", (post.id,))
|
||||
conn.commit()
|
||||
deleted_count += 1
|
||||
except Exception as e:
|
||||
errors.append(f"{post.id}: {e!s}")
|
||||
|
||||
return {
|
||||
"deleted": deleted_count,
|
||||
"errors": errors if errors else None,
|
||||
"message": f"Deleted {deleted_count} files from blacklisted authors/subreddits",
|
||||
}
|
||||
|
||||
|
||||
@router.get("/api/media/cleanup-preview")
|
||||
async def preview_media_cleanup(media_type: str = Query(..., description="video or gif")):
|
||||
"""Preview how many files of a specific type would be deleted."""
|
||||
if media_type not in ("video", "gif"):
|
||||
raise HTTPException(status_code=400, detail="media_type must be 'video' or 'gif'")
|
||||
|
||||
db = Database()
|
||||
count = db.get_total_media_count(media_type=media_type)
|
||||
return {"media_type": media_type, "count": count}
|
||||
|
||||
|
||||
@router.post("/api/media/cleanup-by-type")
|
||||
async def cleanup_media_by_type(media_type: str = Query(..., description="video or gif")):
|
||||
"""Delete all media of a specific type (video or gif)."""
|
||||
if media_type not in ("video", "gif"):
|
||||
raise HTTPException(status_code=400, detail="media_type must be 'video' or 'gif'")
|
||||
|
||||
db = Database()
|
||||
|
||||
with db._get_connection() as conn:
|
||||
cursor = conn.execute(
|
||||
"SELECT id, local_path FROM posts WHERE media_type = ? AND local_path IS NOT NULL",
|
||||
(media_type,),
|
||||
)
|
||||
posts = cursor.fetchall()
|
||||
|
||||
deleted_count = 0
|
||||
errors = []
|
||||
|
||||
for post_id, local_path in posts:
|
||||
try:
|
||||
if local_path:
|
||||
_delete_media_files(Path(local_path))
|
||||
|
||||
with db._get_connection() as conn:
|
||||
conn.execute("DELETE FROM posts WHERE id = ?", (post_id,))
|
||||
conn.commit()
|
||||
deleted_count += 1
|
||||
except Exception as e:
|
||||
errors.append(f"{post_id}: {e!s}")
|
||||
|
||||
return {
|
||||
"deleted": deleted_count,
|
||||
"media_type": media_type,
|
||||
"errors": errors if errors else None,
|
||||
"message": f"Deleted {deleted_count} {media_type} files",
|
||||
}
|
||||
|
||||
|
||||
def _delete_media_files(file_path: Path) -> None:
|
||||
"""Delete media file, its sidecar, and thumbnail."""
|
||||
if file_path.exists():
|
||||
file_path.unlink()
|
||||
|
||||
sidecar_path = file_path.with_suffix(file_path.suffix + ".json")
|
||||
if sidecar_path.exists():
|
||||
sidecar_path.unlink()
|
||||
|
||||
thumb_path = THUMBS_DIR / f"{file_path.name}.jpg"
|
||||
if thumb_path.exists():
|
||||
thumb_path.unlink()
|
||||
392
src/web/routers/scheduler.py
Normal file
392
src/web/routers/scheduler.py
Normal file
|
|
@ -0,0 +1,392 @@
|
|||
"""Scheduler and collector API routes."""
|
||||
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
from contextlib import suppress
|
||||
from datetime import datetime
|
||||
|
||||
import yaml
|
||||
from apscheduler.jobstores.sqlalchemy import SQLAlchemyJobStore
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from apscheduler.triggers.interval import IntervalTrigger
|
||||
from fastapi import APIRouter, BackgroundTasks, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ...database import Database
|
||||
from .. import config_manager
|
||||
from ..deps import (
|
||||
PROJECT_DIR,
|
||||
SCHEDULER_CONFIG_PATH,
|
||||
SCHEDULER_DB_PATH,
|
||||
TIMEZONE,
|
||||
collector_lock,
|
||||
collector_status,
|
||||
)
|
||||
|
||||
router = APIRouter(tags=["scheduler"])
|
||||
|
||||
# Scheduler setup
|
||||
jobstores = {"default": SQLAlchemyJobStore(url=f"sqlite:///{SCHEDULER_DB_PATH}")}
|
||||
scheduler = BackgroundScheduler(jobstores=jobstores, timezone=TIMEZONE)
|
||||
|
||||
# Scheduler state
|
||||
scheduler_config: dict = {
|
||||
"enabled": False,
|
||||
"interval_hours": 6,
|
||||
"mode": "interval",
|
||||
}
|
||||
|
||||
|
||||
def load_scheduler_config():
|
||||
"""Load scheduler configuration from YAML file."""
|
||||
if SCHEDULER_CONFIG_PATH.exists():
|
||||
with open(SCHEDULER_CONFIG_PATH) as f:
|
||||
config = yaml.safe_load(f) or {}
|
||||
scheduler_config["enabled"] = config.get("enabled", False)
|
||||
scheduler_config["interval_hours"] = config.get("interval_hours", 6)
|
||||
scheduler_config["mode"] = config.get("mode", "interval")
|
||||
scheduler_config["specific_times"] = config.get("specific_times", ["00:00", "06:00", "12:00", "18:00"])
|
||||
|
||||
|
||||
def _save_scheduler_config():
|
||||
"""Save scheduler configuration to YAML file."""
|
||||
with open(SCHEDULER_CONFIG_PATH, "w") as f:
|
||||
yaml.dump(
|
||||
{
|
||||
"enabled": scheduler_config["enabled"],
|
||||
"interval_hours": scheduler_config["interval_hours"],
|
||||
"mode": scheduler_config["mode"],
|
||||
"specific_times": scheduler_config.get("specific_times", ["00:00", "06:00", "12:00", "18:00"]),
|
||||
},
|
||||
f,
|
||||
)
|
||||
|
||||
|
||||
def setup_scheduler_job():
|
||||
"""Setup or update the scheduler job based on current config."""
|
||||
with suppress(Exception):
|
||||
scheduler.remove_job("collector_job")
|
||||
|
||||
if not scheduler_config["enabled"]:
|
||||
return
|
||||
|
||||
if scheduler_config["mode"] == "interval":
|
||||
trigger = IntervalTrigger(hours=scheduler_config["interval_hours"])
|
||||
scheduler.add_job(
|
||||
_run_collector_scheduled,
|
||||
trigger=trigger,
|
||||
id="collector_job",
|
||||
name="Reddit Collector",
|
||||
replace_existing=True,
|
||||
)
|
||||
else:
|
||||
times = scheduler_config.get("specific_times", ["00:00", "06:00", "12:00", "18:00"])
|
||||
if times:
|
||||
hours = []
|
||||
for t in times:
|
||||
try:
|
||||
h, _m = t.split(":")
|
||||
hours.append(int(h))
|
||||
except ValueError:
|
||||
pass
|
||||
if hours:
|
||||
hours_str = ",".join(str(h) for h in sorted(set(hours)))
|
||||
trigger = CronTrigger(hour=hours_str, minute=0)
|
||||
scheduler.add_job(
|
||||
_run_collector_scheduled,
|
||||
trigger=trigger,
|
||||
id="collector_job",
|
||||
name="Reddit Collector",
|
||||
replace_existing=True,
|
||||
)
|
||||
|
||||
|
||||
def _run_collector_scheduled():
|
||||
"""Run collector from scheduler (records to history)."""
|
||||
with collector_lock:
|
||||
if collector_status["running"]:
|
||||
return
|
||||
collector_status["running"] = True
|
||||
collector_status["last_run"] = datetime.now().isoformat()
|
||||
|
||||
db = Database()
|
||||
run_id = db.add_scheduler_run(datetime.now())
|
||||
|
||||
posts_processed = 0
|
||||
posts_downloaded = 0
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["python3", "-m", "src.main"],
|
||||
cwd=PROJECT_DIR,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=14400,
|
||||
)
|
||||
|
||||
if result.stdout:
|
||||
processed_match = re.search(r"Posts processed:\s*(\d+)", result.stdout)
|
||||
downloaded_match = re.search(r"New downloads:\s*(\d+)", result.stdout)
|
||||
if processed_match:
|
||||
posts_processed = int(processed_match.group(1))
|
||||
if downloaded_match:
|
||||
posts_downloaded = int(downloaded_match.group(1))
|
||||
|
||||
if result.returncode == 0:
|
||||
with collector_lock:
|
||||
collector_status["last_result"] = "success"
|
||||
db.finish_scheduler_run(run_id, "success", posts_processed, posts_downloaded)
|
||||
else:
|
||||
with collector_lock:
|
||||
collector_status["last_result"] = "error"
|
||||
db.finish_scheduler_run(
|
||||
run_id, "error", posts_processed, posts_downloaded, result.stderr[:500] if result.stderr else None
|
||||
)
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
with collector_lock:
|
||||
collector_status["last_result"] = "timeout"
|
||||
db.finish_scheduler_run(
|
||||
run_id, "timeout", posts_processed, posts_downloaded, "Execution timed out after 4 hours"
|
||||
)
|
||||
except Exception as e:
|
||||
with collector_lock:
|
||||
collector_status["last_result"] = f"error: {e!s}"
|
||||
db.finish_scheduler_run(run_id, "error", posts_processed, posts_downloaded, str(e)[:500])
|
||||
finally:
|
||||
with collector_lock:
|
||||
collector_status["running"] = False
|
||||
|
||||
|
||||
def _run_collector():
|
||||
"""Run the collector in background."""
|
||||
with collector_lock:
|
||||
collector_status["running"] = True
|
||||
collector_status["last_run"] = datetime.now().isoformat()
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["python3", "-m", "src.main"],
|
||||
cwd=PROJECT_DIR,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=14400,
|
||||
)
|
||||
with collector_lock:
|
||||
collector_status["last_result"] = "success" if result.returncode == 0 else "error"
|
||||
except subprocess.TimeoutExpired:
|
||||
with collector_lock:
|
||||
collector_status["last_result"] = "timeout"
|
||||
except Exception as e:
|
||||
with collector_lock:
|
||||
collector_status["last_result"] = f"error: {e!s}"
|
||||
finally:
|
||||
with collector_lock:
|
||||
collector_status["running"] = False
|
||||
|
||||
|
||||
def _run_individual_collection(target_type: str, target_name: str, media_types: list[str], limit: int):
|
||||
"""Run collection for a single user or subreddit."""
|
||||
import tempfile
|
||||
|
||||
with collector_lock:
|
||||
if collector_status["running"]:
|
||||
return
|
||||
collector_status["running"] = True
|
||||
collector_status["last_run"] = datetime.now().isoformat()
|
||||
|
||||
db = Database()
|
||||
run_id = db.add_scheduler_run(datetime.now())
|
||||
|
||||
posts_processed = 0
|
||||
posts_downloaded = 0
|
||||
|
||||
try:
|
||||
base_config_path = PROJECT_DIR / "config.yaml"
|
||||
with open(base_config_path) as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
if target_type == "user":
|
||||
config["targets"] = {"subreddits": [], "users": [{"name": target_name, "limit": limit}]}
|
||||
else:
|
||||
config["targets"] = {"subreddits": [{"name": target_name, "limit": limit, "sort": "new"}], "users": []}
|
||||
|
||||
config["download"]["media_types"] = media_types
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as tmp:
|
||||
yaml.dump(config, tmp)
|
||||
tmp_config_path = tmp.name
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["python3", "-m", "src.main", "-c", tmp_config_path],
|
||||
cwd=PROJECT_DIR,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=7200,
|
||||
)
|
||||
|
||||
if result.stdout:
|
||||
processed_match = re.search(r"Posts processed:\s*(\d+)", result.stdout)
|
||||
downloaded_match = re.search(r"New downloads:\s*(\d+)", result.stdout)
|
||||
if processed_match:
|
||||
posts_processed = int(processed_match.group(1))
|
||||
if downloaded_match:
|
||||
posts_downloaded = int(downloaded_match.group(1))
|
||||
|
||||
if result.returncode == 0:
|
||||
with collector_lock:
|
||||
collector_status["last_result"] = "success"
|
||||
db.finish_scheduler_run(run_id, "success", posts_processed, posts_downloaded)
|
||||
else:
|
||||
with collector_lock:
|
||||
collector_status["last_result"] = "error"
|
||||
db.finish_scheduler_run(
|
||||
run_id, "error", posts_processed, posts_downloaded, result.stderr[:500] if result.stderr else None
|
||||
)
|
||||
finally:
|
||||
os.unlink(tmp_config_path)
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
with collector_lock:
|
||||
collector_status["last_result"] = "timeout"
|
||||
db.finish_scheduler_run(run_id, "timeout", posts_processed, posts_downloaded, "Execution timed out")
|
||||
except Exception as e:
|
||||
with collector_lock:
|
||||
collector_status["last_result"] = f"error: {e!s}"
|
||||
db.finish_scheduler_run(run_id, "error", posts_processed, posts_downloaded, str(e)[:500])
|
||||
finally:
|
||||
with collector_lock:
|
||||
collector_status["running"] = False
|
||||
|
||||
|
||||
# API Endpoints
|
||||
|
||||
|
||||
@router.get("/api/collector/status")
|
||||
async def get_collector_status():
|
||||
"""Get collector status."""
|
||||
return collector_status
|
||||
|
||||
|
||||
@router.post("/api/collector/run")
|
||||
async def trigger_collector(background_tasks: BackgroundTasks):
|
||||
"""Trigger the collector to run."""
|
||||
if collector_status["running"]:
|
||||
raise HTTPException(status_code=409, detail="Collector is already running")
|
||||
|
||||
background_tasks.add_task(_run_collector)
|
||||
return {"message": "Collector started"}
|
||||
|
||||
|
||||
class SchedulerConfigUpdate(BaseModel):
|
||||
enabled: bool
|
||||
interval_hours: int = 6
|
||||
mode: str = "interval"
|
||||
specific_times: list[str] = ["00:00", "06:00", "12:00", "18:00"]
|
||||
|
||||
|
||||
@router.get("/api/scheduler/status")
|
||||
async def get_scheduler_status():
|
||||
"""Get scheduler status including next run time."""
|
||||
job = scheduler.get_job("collector_job")
|
||||
next_run = None
|
||||
if job and job.next_run_time:
|
||||
next_run = job.next_run_time.isoformat()
|
||||
|
||||
db = Database()
|
||||
last_run = db.get_last_scheduler_run()
|
||||
|
||||
return {
|
||||
"enabled": scheduler_config["enabled"],
|
||||
"interval_hours": scheduler_config["interval_hours"],
|
||||
"mode": scheduler_config["mode"],
|
||||
"specific_times": scheduler_config.get("specific_times", ["00:00", "06:00", "12:00", "18:00"]),
|
||||
"next_run": next_run,
|
||||
"is_running": collector_status["running"],
|
||||
"last_run": last_run,
|
||||
}
|
||||
|
||||
|
||||
@router.put("/api/scheduler/config")
|
||||
async def update_scheduler_config(config: SchedulerConfigUpdate):
|
||||
"""Update scheduler configuration."""
|
||||
scheduler_config["enabled"] = config.enabled
|
||||
scheduler_config["interval_hours"] = config.interval_hours
|
||||
scheduler_config["mode"] = config.mode
|
||||
scheduler_config["specific_times"] = config.specific_times
|
||||
|
||||
_save_scheduler_config()
|
||||
setup_scheduler_job()
|
||||
|
||||
return {"message": "Scheduler configuration updated", "config": scheduler_config}
|
||||
|
||||
|
||||
@router.get("/api/scheduler/history")
|
||||
async def get_scheduler_history(limit: int = Query(default=20, le=100)):
|
||||
"""Get scheduler run history."""
|
||||
db = Database()
|
||||
return db.get_scheduler_history(limit)
|
||||
|
||||
|
||||
@router.post("/api/scheduler/run-now")
|
||||
async def run_scheduler_now(background_tasks: BackgroundTasks):
|
||||
"""Trigger an immediate scheduler run."""
|
||||
if collector_status["running"]:
|
||||
raise HTTPException(status_code=409, detail="Collector is already running")
|
||||
|
||||
background_tasks.add_task(_run_collector_scheduled)
|
||||
return {"message": "Collector started (scheduled run)"}
|
||||
|
||||
|
||||
class IndividualCollectRequest(BaseModel):
|
||||
target_type: str
|
||||
target_name: str
|
||||
media_types: list[str] = ["image"]
|
||||
limit: int = 100
|
||||
|
||||
|
||||
@router.post("/api/collect/individual")
|
||||
async def collect_individual(request: IndividualCollectRequest, background_tasks: BackgroundTasks):
|
||||
"""Run collection for a single user or subreddit."""
|
||||
if collector_status["running"]:
|
||||
raise HTTPException(status_code=409, detail="Collector is already running")
|
||||
|
||||
if request.target_type not in ["user", "subreddit"]:
|
||||
raise HTTPException(status_code=400, detail="target_type must be 'user' or 'subreddit'")
|
||||
|
||||
if not request.target_name:
|
||||
raise HTTPException(status_code=400, detail="target_name is required")
|
||||
|
||||
valid_types = ["image", "video", "gif"]
|
||||
for mt in request.media_types:
|
||||
if mt not in valid_types:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid media type: {mt}")
|
||||
|
||||
background_tasks.add_task(
|
||||
_run_individual_collection, request.target_type, request.target_name, request.media_types, request.limit
|
||||
)
|
||||
|
||||
return {
|
||||
"message": f"Collection started for {request.target_type} '{request.target_name}'",
|
||||
"media_types": request.media_types,
|
||||
"limit": request.limit,
|
||||
}
|
||||
|
||||
|
||||
@router.get("/api/collect/targets")
|
||||
async def get_collection_targets():
|
||||
"""Get available targets for individual collection."""
|
||||
db = Database()
|
||||
|
||||
favorite_authors = db.get_favorite_authors()
|
||||
subreddits = config_manager.get_subreddits()
|
||||
users = config_manager.get_users()
|
||||
|
||||
return {
|
||||
"favorite_authors": favorite_authors,
|
||||
"configured_subreddits": [s["name"] for s in subreddits],
|
||||
"configured_users": [u["name"] for u in users],
|
||||
}
|
||||
56
src/web/routers/stats.py
Normal file
56
src/web/routers/stats.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
"""Statistics API routes."""
|
||||
|
||||
import shutil
|
||||
|
||||
from fastapi import APIRouter, Query
|
||||
|
||||
from ...database import Database
|
||||
from ..deps import DOWNLOADS_DIR
|
||||
|
||||
router = APIRouter(tags=["stats"])
|
||||
|
||||
|
||||
@router.get("/api/stats")
|
||||
async def get_stats():
|
||||
"""Get collection statistics."""
|
||||
db = Database()
|
||||
stats = db.get_stats()
|
||||
|
||||
total_size = 0
|
||||
file_count = 0
|
||||
if DOWNLOADS_DIR.exists():
|
||||
for f in DOWNLOADS_DIR.iterdir():
|
||||
if f.is_file() and not f.name.endswith(".json"):
|
||||
total_size += f.stat().st_size
|
||||
file_count += 1
|
||||
|
||||
stats["disk_size_bytes"] = total_size
|
||||
stats["disk_size_mb"] = round(total_size / (1024 * 1024), 2)
|
||||
stats["disk_size_gb"] = round(total_size / (1024 * 1024 * 1024), 2)
|
||||
stats["file_count"] = file_count
|
||||
|
||||
try:
|
||||
disk_usage = shutil.disk_usage(DOWNLOADS_DIR)
|
||||
stats["disk_free_gb"] = round(disk_usage.free / (1024 * 1024 * 1024), 2)
|
||||
stats["disk_total_gb"] = round(disk_usage.total / (1024 * 1024 * 1024), 2)
|
||||
stats["disk_used_percent"] = round((disk_usage.used / disk_usage.total) * 100, 1)
|
||||
except OSError:
|
||||
stats["disk_free_gb"] = 0
|
||||
stats["disk_total_gb"] = 0
|
||||
stats["disk_used_percent"] = 0
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
@router.get("/api/stats/enhanced")
|
||||
async def get_enhanced_stats():
|
||||
"""Get enhanced statistics for dashboard."""
|
||||
db = Database()
|
||||
return db.get_enhanced_stats()
|
||||
|
||||
|
||||
@router.get("/api/stats/recent")
|
||||
async def get_recent_downloads(limit: int = Query(default=10, le=50)):
|
||||
"""Get recent downloads."""
|
||||
db = Database()
|
||||
return db.get_recent_downloads(limit)
|
||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
96
tests/conftest.py
Normal file
96
tests/conftest.py
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
"""Shared test fixtures."""
|
||||
|
||||
import pytest
|
||||
|
||||
from src.config import (
|
||||
BlacklistConfig,
|
||||
Config,
|
||||
DownloadConfig,
|
||||
LoggingConfig,
|
||||
RateLimitConfig,
|
||||
SubredditTarget,
|
||||
TargetsConfig,
|
||||
UserTarget,
|
||||
)
|
||||
from src.database import Database
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tmp_dir(tmp_path):
|
||||
"""Provide a temporary directory."""
|
||||
return tmp_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db(tmp_path):
|
||||
"""Provide a fresh test database."""
|
||||
db_path = tmp_path / "test_media.db"
|
||||
return Database(db_path=str(db_path))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_config():
|
||||
"""Provide a sample configuration."""
|
||||
return Config(
|
||||
targets=TargetsConfig(
|
||||
subreddits=[SubredditTarget(name="pics", limit=25, sort="hot")],
|
||||
users=[UserTarget(name="testuser", limit=10)],
|
||||
),
|
||||
download=DownloadConfig(
|
||||
output_dir="./test_downloads",
|
||||
media_types=["image", "video", "gif"],
|
||||
min_score=10,
|
||||
skip_nsfw=True,
|
||||
max_file_size_mb=100,
|
||||
),
|
||||
rate_limit=RateLimitConfig(requests_per_minute=10, download_delay_seconds=0.1),
|
||||
logging=LoggingConfig(level="DEBUG", file=None),
|
||||
blacklist=BlacklistConfig(
|
||||
authors=["spammer"],
|
||||
subreddits=["spam_sub"],
|
||||
title_keywords=["buy now"],
|
||||
domains=["malware.com"],
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config_file(tmp_path):
|
||||
"""Provide a temporary config file."""
|
||||
config_content = """
|
||||
targets:
|
||||
subreddits:
|
||||
- name: "pics"
|
||||
limit: 25
|
||||
sort: "hot"
|
||||
users:
|
||||
- name: "testuser"
|
||||
limit: 10
|
||||
|
||||
download:
|
||||
output_dir: "./downloads"
|
||||
media_types:
|
||||
- "image"
|
||||
- "video"
|
||||
min_score: 10
|
||||
skip_nsfw: true
|
||||
max_file_size_mb: 100
|
||||
|
||||
rate_limit:
|
||||
requests_per_minute: 10
|
||||
download_delay_seconds: 2
|
||||
|
||||
logging:
|
||||
level: "INFO"
|
||||
file: null
|
||||
|
||||
blacklist:
|
||||
authors:
|
||||
- "spammer"
|
||||
subreddits: []
|
||||
title_keywords: []
|
||||
domains: []
|
||||
"""
|
||||
config_path = tmp_path / "config.yaml"
|
||||
config_path.write_text(config_content)
|
||||
return str(config_path)
|
||||
76
tests/test_config.py
Normal file
76
tests/test_config.py
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
"""Tests for configuration loading and validation."""
|
||||
|
||||
import pytest
|
||||
|
||||
from src.config import load_config, setup_logging
|
||||
|
||||
|
||||
class TestLoadConfig:
|
||||
def test_load_valid_config(self, config_file):
|
||||
config = load_config(config_file)
|
||||
assert config.targets.subreddits[0].name == "pics"
|
||||
assert config.targets.subreddits[0].limit == 25
|
||||
assert config.targets.users[0].name == "testuser"
|
||||
assert config.download.min_score == 10
|
||||
assert config.download.skip_nsfw is True
|
||||
assert config.rate_limit.requests_per_minute == 10
|
||||
|
||||
def test_missing_config_raises(self):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
load_config("nonexistent.yaml")
|
||||
|
||||
def test_empty_targets_raises(self, tmp_path):
|
||||
config_path = tmp_path / "empty.yaml"
|
||||
config_path.write_text("targets:\n subreddits: []\n users: []\n")
|
||||
with pytest.raises(ValueError, match="No targets configured"):
|
||||
load_config(str(config_path))
|
||||
|
||||
def test_blacklist_lowercase(self, tmp_path):
|
||||
config_path = tmp_path / "config.yaml"
|
||||
config_path.write_text("""
|
||||
targets:
|
||||
subreddits:
|
||||
- name: "test"
|
||||
limit: 10
|
||||
download: {}
|
||||
rate_limit: {}
|
||||
logging: {}
|
||||
blacklist:
|
||||
authors:
|
||||
- "SpAmMeR"
|
||||
subreddits:
|
||||
- "BadSub"
|
||||
title_keywords:
|
||||
- "BUY NOW"
|
||||
domains:
|
||||
- "Evil.COM"
|
||||
""")
|
||||
config = load_config(str(config_path))
|
||||
assert config.blacklist.authors == ["spammer"]
|
||||
assert config.blacklist.subreddits == ["badsub"]
|
||||
assert config.blacklist.title_keywords == ["buy now"]
|
||||
assert config.blacklist.domains == ["evil.com"]
|
||||
|
||||
def test_default_values(self, tmp_path):
|
||||
config_path = tmp_path / "minimal.yaml"
|
||||
config_path.write_text("""
|
||||
targets:
|
||||
subreddits:
|
||||
- name: "test"
|
||||
""")
|
||||
config = load_config(str(config_path))
|
||||
assert config.download.output_dir == "./downloads"
|
||||
assert config.download.min_score == 10
|
||||
assert config.rate_limit.requests_per_minute == 10
|
||||
|
||||
|
||||
class TestSetupLogging:
|
||||
def test_creates_logger(self, sample_config):
|
||||
logger = setup_logging(sample_config.logging)
|
||||
assert logger.name == "reddit_collector"
|
||||
|
||||
def test_log_level(self, sample_config):
|
||||
import logging
|
||||
|
||||
logger = setup_logging(sample_config.logging)
|
||||
assert logger.level == logging.DEBUG
|
||||
216
tests/test_database.py
Normal file
216
tests/test_database.py
Normal file
|
|
@ -0,0 +1,216 @@
|
|||
"""Tests for database operations."""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from src.database import PostRecord
|
||||
|
||||
|
||||
def _make_post(post_id="test123", **kwargs):
|
||||
"""Helper to create a PostRecord with defaults."""
|
||||
defaults = {
|
||||
"id": post_id,
|
||||
"subreddit": "pics",
|
||||
"author": "testuser",
|
||||
"title": "Test Post",
|
||||
"url": "https://reddit.com/test",
|
||||
"media_url": "https://i.redd.it/test.jpg",
|
||||
"media_type": "image",
|
||||
"score": 100,
|
||||
"created_utc": 1700000000.0,
|
||||
"downloaded_at": None,
|
||||
"local_path": None,
|
||||
"file_hash": None,
|
||||
"permalink": "/r/pics/comments/test123/test_post/",
|
||||
"source_type": "subreddit",
|
||||
"flair": "OC",
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return PostRecord(**defaults)
|
||||
|
||||
|
||||
class TestDatabase:
|
||||
def test_init_creates_tables(self, db):
|
||||
with db._get_connection() as conn:
|
||||
cursor = conn.execute("SELECT name FROM sqlite_master WHERE type='table'")
|
||||
tables = {row[0] for row in cursor.fetchall()}
|
||||
assert "posts" in tables
|
||||
assert "favorites" in tables
|
||||
assert "scheduler_history" in tables
|
||||
|
||||
def test_add_and_get_post(self, db):
|
||||
post = _make_post()
|
||||
db.add_post(post)
|
||||
retrieved = db.get_post("test123")
|
||||
assert retrieved is not None
|
||||
assert retrieved.id == "test123"
|
||||
assert retrieved.subreddit == "pics"
|
||||
assert retrieved.author == "testuser"
|
||||
|
||||
def test_post_exists(self, db):
|
||||
assert db.post_exists("test123") is False
|
||||
db.add_post(_make_post())
|
||||
assert db.post_exists("test123") is True
|
||||
|
||||
def test_mark_downloaded(self, db):
|
||||
db.add_post(_make_post())
|
||||
db.mark_downloaded("test123", "/path/to/file.jpg", "abc123hash")
|
||||
post = db.get_post("test123")
|
||||
assert post.local_path == "/path/to/file.jpg"
|
||||
assert post.file_hash == "abc123hash"
|
||||
assert post.downloaded_at is not None
|
||||
|
||||
def test_hash_exists(self, db):
|
||||
assert db.hash_exists("abc123hash") is None
|
||||
post = _make_post()
|
||||
db.add_post(post)
|
||||
db.mark_downloaded("test123", "/path/to/file.jpg", "abc123hash")
|
||||
assert db.hash_exists("abc123hash") == "/path/to/file.jpg"
|
||||
|
||||
def test_get_stats(self, db):
|
||||
db.add_post(_make_post("p1"))
|
||||
db.mark_downloaded("p1", "/path/p1.jpg", "hash1")
|
||||
db.add_post(_make_post("p2", media_type="video"))
|
||||
db.mark_downloaded("p2", "/path/p2.mp4", "hash2")
|
||||
|
||||
stats = db.get_stats()
|
||||
assert stats["total_posts"] == 2
|
||||
assert stats["downloaded"] == 2
|
||||
|
||||
def test_get_media_files_sorting(self, db):
|
||||
db.add_post(_make_post("p1", score=10, created_utc=1000.0))
|
||||
db.mark_downloaded("p1", "/p1.jpg", "h1")
|
||||
db.add_post(_make_post("p2", score=500, created_utc=2000.0))
|
||||
db.mark_downloaded("p2", "/p2.jpg", "h2")
|
||||
|
||||
newest = db.get_media_files(sort="newest")
|
||||
assert newest[0]["id"] == "p2"
|
||||
|
||||
oldest = db.get_media_files(sort="oldest")
|
||||
assert oldest[0]["id"] == "p1"
|
||||
|
||||
score_high = db.get_media_files(sort="score_high")
|
||||
assert score_high[0]["id"] == "p2"
|
||||
|
||||
def test_get_media_files_filtering(self, db):
|
||||
db.add_post(_make_post("p1", subreddit="pics"))
|
||||
db.mark_downloaded("p1", "/p1.jpg", "h1")
|
||||
db.add_post(_make_post("p2", subreddit="videos", media_type="video"))
|
||||
db.mark_downloaded("p2", "/p2.mp4", "h2")
|
||||
|
||||
pics_only = db.get_media_files(subreddit="pics")
|
||||
assert len(pics_only) == 1
|
||||
assert pics_only[0]["subreddit"] == "pics"
|
||||
|
||||
videos_only = db.get_media_files(media_type="video")
|
||||
assert len(videos_only) == 1
|
||||
|
||||
def test_total_media_count(self, db):
|
||||
db.add_post(_make_post("p1"))
|
||||
db.mark_downloaded("p1", "/p1.jpg", "h1")
|
||||
db.add_post(_make_post("p2"))
|
||||
db.mark_downloaded("p2", "/p2.jpg", "h2")
|
||||
|
||||
assert db.get_total_media_count() == 2
|
||||
assert db.get_total_media_count(subreddit="pics") == 2
|
||||
assert db.get_total_media_count(subreddit="nonexistent") == 0
|
||||
|
||||
|
||||
class TestFavorites:
|
||||
def test_add_favorite(self, db):
|
||||
db.add_post(_make_post())
|
||||
assert db.add_favorite("test123") is True
|
||||
assert db.add_favorite("test123") is False # duplicate
|
||||
|
||||
def test_remove_favorite(self, db):
|
||||
db.add_post(_make_post())
|
||||
db.add_favorite("test123")
|
||||
assert db.remove_favorite("test123") is True
|
||||
assert db.remove_favorite("test123") is False
|
||||
|
||||
def test_is_favorite(self, db):
|
||||
db.add_post(_make_post())
|
||||
assert db.is_favorite("test123") is False
|
||||
db.add_favorite("test123")
|
||||
assert db.is_favorite("test123") is True
|
||||
|
||||
def test_get_favorites(self, db):
|
||||
db.add_post(_make_post("p1"))
|
||||
db.mark_downloaded("p1", "/p1.jpg", "h1")
|
||||
db.add_post(_make_post("p2"))
|
||||
db.mark_downloaded("p2", "/p2.jpg", "h2")
|
||||
db.add_favorite("p1")
|
||||
|
||||
favs = db.get_favorites()
|
||||
assert len(favs) == 1
|
||||
assert favs[0]["id"] == "p1"
|
||||
|
||||
def test_count_favorites(self, db):
|
||||
db.add_post(_make_post("p1"))
|
||||
db.add_post(_make_post("p2"))
|
||||
db.add_favorite("p1")
|
||||
db.add_favorite("p2")
|
||||
assert db.count_favorites() == 2
|
||||
|
||||
def test_get_favorite_authors(self, db):
|
||||
db.add_post(_make_post("p1", author="alice"))
|
||||
db.add_post(_make_post("p2", author="bob"))
|
||||
db.add_favorite("p1")
|
||||
|
||||
authors = db.get_favorite_authors()
|
||||
assert "alice" in authors
|
||||
assert "bob" not in authors
|
||||
|
||||
|
||||
class TestSchedulerHistory:
|
||||
def test_add_and_finish_run(self, db):
|
||||
run_id = db.add_scheduler_run(datetime.now())
|
||||
assert run_id > 0
|
||||
|
||||
db.finish_scheduler_run(run_id, "success", 10, 5)
|
||||
history = db.get_scheduler_history()
|
||||
assert len(history) == 1
|
||||
assert history[0]["status"] == "success"
|
||||
assert history[0]["posts_processed"] == 10
|
||||
|
||||
def test_get_last_scheduler_run(self, db):
|
||||
assert db.get_last_scheduler_run() is None
|
||||
|
||||
db.add_scheduler_run(datetime.now())
|
||||
assert db.get_last_scheduler_run() is not None
|
||||
|
||||
|
||||
class TestPostsByAuthorsAndSubreddits:
|
||||
def test_get_posts_by_authors(self, db):
|
||||
db.add_post(_make_post("p1", author="alice"))
|
||||
db.mark_downloaded("p1", "/p1.jpg", "h1")
|
||||
db.add_post(_make_post("p2", author="bob"))
|
||||
db.mark_downloaded("p2", "/p2.jpg", "h2")
|
||||
|
||||
posts = db.get_posts_by_authors(["alice"])
|
||||
assert len(posts) == 1
|
||||
assert posts[0].author == "alice"
|
||||
|
||||
def test_case_insensitive_author_search(self, db):
|
||||
db.add_post(_make_post("p1", author="Alice"))
|
||||
db.mark_downloaded("p1", "/p1.jpg", "h1")
|
||||
|
||||
posts = db.get_posts_by_authors(["alice"])
|
||||
assert len(posts) == 1
|
||||
|
||||
def test_count_posts_by_authors(self, db):
|
||||
db.add_post(_make_post("p1", author="alice"))
|
||||
db.mark_downloaded("p1", "/p1.jpg", "h1")
|
||||
db.add_post(_make_post("p2", author="alice"))
|
||||
db.mark_downloaded("p2", "/p2.jpg", "h2")
|
||||
|
||||
assert db.count_posts_by_authors(["alice"]) == 2
|
||||
assert db.count_posts_by_authors(["bob"]) == 0
|
||||
assert db.count_posts_by_authors([]) == 0
|
||||
|
||||
def test_get_posts_by_subreddits(self, db):
|
||||
db.add_post(_make_post("p1", subreddit="pics"))
|
||||
db.mark_downloaded("p1", "/p1.jpg", "h1")
|
||||
|
||||
posts = db.get_posts_by_subreddits(["pics"])
|
||||
assert len(posts) == 1
|
||||
assert posts[0].subreddit == "pics"
|
||||
47
tests/test_downloader.py
Normal file
47
tests/test_downloader.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
"""Tests for downloader module."""
|
||||
|
||||
import hashlib
|
||||
|
||||
from src.config import DownloadConfig, RateLimitConfig
|
||||
from src.downloader import Downloader
|
||||
|
||||
|
||||
class TestDownloader:
|
||||
def test_compute_hash(self, tmp_path):
|
||||
config = DownloadConfig(output_dir=str(tmp_path))
|
||||
rate = RateLimitConfig(download_delay_seconds=0)
|
||||
downloader = Downloader(config, rate)
|
||||
|
||||
test_file = tmp_path / "test.txt"
|
||||
test_file.write_bytes(b"hello world")
|
||||
|
||||
file_hash = downloader.compute_hash(str(test_file))
|
||||
expected = hashlib.md5(b"hello world").hexdigest()
|
||||
assert file_hash == expected
|
||||
|
||||
def test_get_extension_from_url(self, tmp_path):
|
||||
config = DownloadConfig(output_dir=str(tmp_path))
|
||||
rate = RateLimitConfig(download_delay_seconds=0)
|
||||
downloader = Downloader(config, rate)
|
||||
|
||||
assert downloader._get_extension("https://example.com/image.jpg", None) == ".jpg"
|
||||
assert downloader._get_extension("https://example.com/image.png", None) == ".png"
|
||||
assert downloader._get_extension("https://example.com/video.mp4", None) == ".mp4"
|
||||
|
||||
def test_get_extension_from_content_type(self, tmp_path):
|
||||
config = DownloadConfig(output_dir=str(tmp_path))
|
||||
rate = RateLimitConfig(download_delay_seconds=0)
|
||||
downloader = Downloader(config, rate)
|
||||
|
||||
assert downloader._get_extension("https://example.com/blah", "image/jpeg") == ".jpg"
|
||||
assert downloader._get_extension("https://example.com/blah", "image/png") == ".png"
|
||||
assert downloader._get_extension("https://example.com/blah", "video/mp4") == ".mp4"
|
||||
|
||||
def test_sanitize_name(self, tmp_path):
|
||||
config = DownloadConfig(output_dir=str(tmp_path))
|
||||
rate = RateLimitConfig(download_delay_seconds=0)
|
||||
downloader = Downloader(config, rate)
|
||||
|
||||
assert downloader._sanitize_name("normal_name") == "normal_name"
|
||||
assert downloader._sanitize_name("has spaces!@#") == "has_spaces___"
|
||||
assert downloader._sanitize_name("with-dash") == "with-dash"
|
||||
48
tests/test_extractors.py
Normal file
48
tests/test_extractors.py
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
"""Tests for URL extractors."""
|
||||
|
||||
from src.extractors import extract_media_url
|
||||
from src.extractors.imgur import extract_imgur_url
|
||||
|
||||
|
||||
class TestExtractMediaUrl:
|
||||
def test_direct_image_passthrough(self):
|
||||
url = "https://i.redd.it/test123.jpg"
|
||||
result_url, result_type = extract_media_url(url, "image")
|
||||
assert result_url == url
|
||||
assert result_type == "image"
|
||||
|
||||
def test_unknown_domain_passthrough(self):
|
||||
url = "https://example.com/image.jpg"
|
||||
result_url, result_type = extract_media_url(url, "image")
|
||||
assert result_url == url
|
||||
assert result_type == "image"
|
||||
|
||||
|
||||
class TestImgurExtractor:
|
||||
def test_direct_imgur_url(self):
|
||||
url = "https://i.imgur.com/abc123.jpg"
|
||||
result_url, result_type = extract_imgur_url(url)
|
||||
assert result_url == url
|
||||
assert result_type == "image"
|
||||
|
||||
def test_gifv_to_mp4(self):
|
||||
url = "https://i.imgur.com/abc123.gifv"
|
||||
result_url, result_type = extract_imgur_url(url)
|
||||
assert result_url == "https://i.imgur.com/abc123.mp4"
|
||||
assert result_type == "video"
|
||||
|
||||
def test_imgur_page_url(self):
|
||||
url = "https://imgur.com/abc123"
|
||||
result_url, result_type = extract_imgur_url(url)
|
||||
assert result_url == "https://i.imgur.com/abc123.jpg"
|
||||
assert result_type == "image"
|
||||
|
||||
def test_imgur_album_unsupported(self):
|
||||
url = "https://imgur.com/a/abc123"
|
||||
result_url, _result_type = extract_imgur_url(url)
|
||||
assert result_url is None
|
||||
|
||||
def test_imgur_gallery_unsupported(self):
|
||||
url = "https://imgur.com/gallery/abc123"
|
||||
result_url, _result_type = extract_imgur_url(url)
|
||||
assert result_url is None
|
||||
87
tests/test_main.py
Normal file
87
tests/test_main.py
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
"""Tests for main collector logic."""
|
||||
|
||||
from src.main import is_domain_blacklisted, should_download_media, should_download_post
|
||||
from src.reddit_client import Post
|
||||
|
||||
|
||||
def _make_reddit_post(**kwargs):
|
||||
"""Helper to create a Post with defaults."""
|
||||
defaults = {
|
||||
"id": "test123",
|
||||
"subreddit": "pics",
|
||||
"author": "testuser",
|
||||
"title": "Test Post",
|
||||
"url": "https://i.redd.it/test.jpg",
|
||||
"score": 100,
|
||||
"created_utc": 1700000000.0,
|
||||
"over_18": False,
|
||||
"is_gallery": False,
|
||||
"preview": None,
|
||||
"media_metadata": None,
|
||||
"permalink": "/r/pics/test123",
|
||||
"flair": None,
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return Post(**defaults)
|
||||
|
||||
|
||||
class TestShouldDownloadPost:
|
||||
def test_normal_post_allowed(self, sample_config):
|
||||
post = _make_reddit_post()
|
||||
allowed, reason = should_download_post(post, sample_config)
|
||||
assert allowed is True
|
||||
assert reason == ""
|
||||
|
||||
def test_nsfw_skipped(self, sample_config):
|
||||
post = _make_reddit_post(over_18=True)
|
||||
allowed, reason = should_download_post(post, sample_config)
|
||||
assert allowed is False
|
||||
assert reason == "nsfw"
|
||||
|
||||
def test_low_score_skipped(self, sample_config):
|
||||
post = _make_reddit_post(score=1)
|
||||
allowed, reason = should_download_post(post, sample_config)
|
||||
assert allowed is False
|
||||
assert reason == "score"
|
||||
|
||||
def test_blacklisted_author(self, sample_config):
|
||||
post = _make_reddit_post(author="spammer")
|
||||
allowed, reason = should_download_post(post, sample_config)
|
||||
assert allowed is False
|
||||
assert reason == "blacklist_author"
|
||||
|
||||
def test_blacklisted_subreddit(self, sample_config):
|
||||
post = _make_reddit_post(subreddit="spam_sub")
|
||||
allowed, reason = should_download_post(post, sample_config)
|
||||
assert allowed is False
|
||||
assert reason == "blacklist_subreddit"
|
||||
|
||||
def test_blacklisted_keyword(self, sample_config):
|
||||
post = _make_reddit_post(title="Amazing! Buy now for cheap!")
|
||||
allowed, reason = should_download_post(post, sample_config)
|
||||
assert allowed is False
|
||||
assert reason == "blacklist_keyword"
|
||||
|
||||
|
||||
class TestIsDomainBlacklisted:
|
||||
def test_blacklisted_domain(self):
|
||||
assert is_domain_blacklisted("https://malware.com/image.jpg", ["malware.com"]) is True
|
||||
|
||||
def test_clean_domain(self):
|
||||
assert is_domain_blacklisted("https://i.redd.it/image.jpg", ["malware.com"]) is False
|
||||
|
||||
def test_empty_blacklist(self):
|
||||
assert is_domain_blacklisted("https://example.com", []) is False
|
||||
|
||||
def test_empty_url(self):
|
||||
assert is_domain_blacklisted("", ["malware.com"]) is False
|
||||
|
||||
|
||||
class TestShouldDownloadMedia:
|
||||
def test_allowed_type(self, sample_config):
|
||||
assert should_download_media("image", sample_config) is True
|
||||
assert should_download_media("video", sample_config) is True
|
||||
|
||||
def test_disallowed_type(self, sample_config):
|
||||
sample_config.download.media_types = ["image"]
|
||||
assert should_download_media("video", sample_config) is False
|
||||
43
tests/test_reddit_client.py
Normal file
43
tests/test_reddit_client.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
"""Tests for Reddit client."""
|
||||
|
||||
import time
|
||||
|
||||
from src.reddit_client import Post, RateLimiter
|
||||
|
||||
|
||||
class TestRateLimiter:
|
||||
def test_first_request_no_wait(self):
|
||||
limiter = RateLimiter(requests_per_minute=60)
|
||||
start = time.time()
|
||||
limiter.wait()
|
||||
elapsed = time.time() - start
|
||||
assert elapsed < 0.1
|
||||
|
||||
def test_respects_rate_limit(self):
|
||||
limiter = RateLimiter(requests_per_minute=120) # 0.5s interval
|
||||
limiter.wait()
|
||||
limiter.last_request = time.time() # simulate request just happened
|
||||
start = time.time()
|
||||
limiter.wait()
|
||||
elapsed = time.time() - start
|
||||
assert elapsed >= 0.4 # Should wait ~0.5s
|
||||
|
||||
|
||||
class TestPost:
|
||||
def test_post_creation(self):
|
||||
post = Post(
|
||||
id="test",
|
||||
subreddit="pics",
|
||||
author="user",
|
||||
title="Title",
|
||||
url="https://example.com",
|
||||
score=100,
|
||||
created_utc=1700000000.0,
|
||||
over_18=False,
|
||||
is_gallery=False,
|
||||
preview=None,
|
||||
media_metadata=None,
|
||||
)
|
||||
assert post.id == "test"
|
||||
assert post.permalink is None
|
||||
assert post.flair is None
|
||||
147
tests/test_sidecar.py
Normal file
147
tests/test_sidecar.py
Normal file
|
|
@ -0,0 +1,147 @@
|
|||
"""Tests for sidecar file generation."""
|
||||
|
||||
import json
|
||||
|
||||
from src.sidecar import generate_filename, write_immich_sidecar
|
||||
|
||||
|
||||
class TestWriteImmichSidecar:
|
||||
def test_creates_sidecar_file(self, tmp_path):
|
||||
media_file = tmp_path / "test.jpg"
|
||||
media_file.write_bytes(b"fake image data")
|
||||
|
||||
sidecar_path = write_immich_sidecar(
|
||||
filepath=str(media_file),
|
||||
subreddit="pics",
|
||||
author="testuser",
|
||||
title="A great photo",
|
||||
score=150,
|
||||
created_utc=1700000000.0,
|
||||
media_type="image",
|
||||
permalink="/r/pics/comments/abc/a_great_photo/",
|
||||
flair="OC",
|
||||
source_type="subreddit",
|
||||
)
|
||||
|
||||
assert sidecar_path.endswith(".json")
|
||||
with open(sidecar_path) as f:
|
||||
data = json.load(f)
|
||||
|
||||
assert "dateTimeOriginal" in data
|
||||
assert data["description"] == "A great photo"
|
||||
assert data["albums"] == ["r/pics"]
|
||||
assert "reddit" in data["tags"]
|
||||
assert "pics" in data["tags"]
|
||||
assert "OC" in data["tags"]
|
||||
assert data["rating"] == 3 # score 150 -> rating 3
|
||||
assert data["people"] == ["testuser"]
|
||||
assert "reddit.com" in data["externalUrl"]
|
||||
|
||||
def test_rating_buckets(self, tmp_path):
|
||||
media_file = tmp_path / "test.jpg"
|
||||
media_file.write_bytes(b"fake")
|
||||
|
||||
test_cases = [
|
||||
(5, 1),
|
||||
(10, 2),
|
||||
(50, 3),
|
||||
(200, 4),
|
||||
(1000, 5),
|
||||
(5000, 5),
|
||||
]
|
||||
for score, expected_rating in test_cases:
|
||||
path = write_immich_sidecar(
|
||||
filepath=str(media_file),
|
||||
subreddit="test",
|
||||
author="user",
|
||||
title="test",
|
||||
score=score,
|
||||
created_utc=1700000000.0,
|
||||
media_type="image",
|
||||
)
|
||||
with open(path) as f:
|
||||
data = json.load(f)
|
||||
assert data["rating"] == expected_rating, f"Score {score} should give rating {expected_rating}"
|
||||
|
||||
def test_deleted_author_no_people(self, tmp_path):
|
||||
media_file = tmp_path / "test.jpg"
|
||||
media_file.write_bytes(b"fake")
|
||||
|
||||
write_immich_sidecar(
|
||||
filepath=str(media_file),
|
||||
subreddit="test",
|
||||
author="[deleted]",
|
||||
title="test",
|
||||
score=10,
|
||||
created_utc=1700000000.0,
|
||||
media_type="image",
|
||||
)
|
||||
with open(str(media_file) + ".json") as f:
|
||||
data = json.load(f)
|
||||
assert "people" not in data
|
||||
|
||||
def test_no_permalink_no_external_url(self, tmp_path):
|
||||
media_file = tmp_path / "test.jpg"
|
||||
media_file.write_bytes(b"fake")
|
||||
|
||||
write_immich_sidecar(
|
||||
filepath=str(media_file),
|
||||
subreddit="test",
|
||||
author="user",
|
||||
title="test",
|
||||
score=10,
|
||||
created_utc=1700000000.0,
|
||||
media_type="image",
|
||||
)
|
||||
with open(str(media_file) + ".json") as f:
|
||||
data = json.load(f)
|
||||
assert "externalUrl" not in data
|
||||
|
||||
|
||||
class TestGenerateFilename:
|
||||
def test_basic_filename(self):
|
||||
name = generate_filename(
|
||||
subreddit="pics",
|
||||
author="testuser",
|
||||
created_utc=1700000000.0,
|
||||
post_id="abc123",
|
||||
ext=".jpg",
|
||||
)
|
||||
assert "pics" in name
|
||||
assert "testuser" in name
|
||||
assert "abc123" in name
|
||||
assert name.endswith(".jpg")
|
||||
|
||||
def test_gallery_index(self):
|
||||
name = generate_filename(
|
||||
subreddit="pics",
|
||||
author="testuser",
|
||||
created_utc=1700000000.0,
|
||||
post_id="abc123",
|
||||
ext=".jpg",
|
||||
gallery_index=3,
|
||||
)
|
||||
assert "_3.jpg" in name
|
||||
|
||||
def test_deleted_author(self):
|
||||
name = generate_filename(
|
||||
subreddit="pics",
|
||||
author="[deleted]",
|
||||
created_utc=1700000000.0,
|
||||
post_id="abc123",
|
||||
ext=".jpg",
|
||||
)
|
||||
assert "unknown" in name
|
||||
|
||||
def test_sanitized_names(self):
|
||||
name = generate_filename(
|
||||
subreddit="pics/test",
|
||||
author="user name!@#",
|
||||
created_utc=1700000000.0,
|
||||
post_id="abc123",
|
||||
ext=".jpg",
|
||||
)
|
||||
# Should not contain special characters
|
||||
assert "/" not in name
|
||||
assert "!" not in name
|
||||
assert "@" not in name
|
||||
84
tests/test_web_api.py
Normal file
84
tests/test_web_api.py
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
"""Tests for web API endpoints."""
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from src.web.app import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
class TestHealthEndpoints:
|
||||
def test_root_returns_html(self):
|
||||
response = client.get("/")
|
||||
assert response.status_code == 200
|
||||
assert "text/html" in response.headers["content-type"]
|
||||
|
||||
def test_get_config(self):
|
||||
response = client.get("/api/config")
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_get_stats(self):
|
||||
response = client.get("/api/stats")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "total_posts" in data
|
||||
|
||||
def test_get_collector_status(self):
|
||||
response = client.get("/api/collector/status")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "running" in data
|
||||
|
||||
def test_get_media_files(self):
|
||||
response = client.get("/api/media?limit=10")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "files" in data
|
||||
assert "total" in data
|
||||
|
||||
def test_get_subreddits(self):
|
||||
response = client.get("/api/subreddits")
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_get_users(self):
|
||||
response = client.get("/api/users")
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_get_blacklist(self):
|
||||
response = client.get("/api/blacklist")
|
||||
assert response.status_code == 200
|
||||
|
||||
def test_get_settings(self):
|
||||
response = client.get("/api/settings")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
class TestMediaEndpoints:
|
||||
def test_file_not_found(self):
|
||||
response = client.get("/api/media/file/nonexistent.jpg")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_thumb_not_found(self):
|
||||
response = client.get("/api/media/thumb/nonexistent.mp4")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_media_info_not_found(self):
|
||||
response = client.get("/api/media/nonexistent/info")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_delete_media_not_found(self):
|
||||
response = client.delete("/api/media/nonexistent")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
class TestValidation:
|
||||
def test_cleanup_invalid_media_type(self):
|
||||
response = client.get("/api/media/cleanup-preview?media_type=invalid")
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_collect_invalid_target_type(self):
|
||||
response = client.post(
|
||||
"/api/collect/individual",
|
||||
json={"target_type": "invalid", "target_name": "test"},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
Loading…
Add table
Add a link
Reference in a new issue