security: fix auth, CORS, file upload, endpoint hardening + test fixes

- Add INTERNAL_SECRET shared-secret auth to /internal/nc-upload endpoint
- Add JWT token validation to WebSocket /ws/versions/{version_id}
- Fix NameError: band_slug → band.slug in internal.py
- Move inline imports to top of internal.py; add missing Member/NextcloudClient imports
- Remove ~15 debug print() statements from auth.py
- Replace Content-Type-only avatar check with extension whitelist + Pillow Image.verify()
- Sanitize exception details in versions.py (no more str(e) in 4xx/5xx responses)
- Restrict CORS allow_methods/allow_headers from "*" to explicit lists
- Add security headers middleware: X-Frame-Options, X-Content-Type-Options, Referrer-Policy
- Reduce JWT expiry from 7 days to 1 hour
- Add Pillow>=10.0 dependency; document INTERNAL_SECRET in .env.example
- Implement missing RedisJobQueue.dequeue() method (required by protocol)
- Fix 5 pre-existing unit test failures: settings env vars conftest, deferred Redis push,
  dequeue method, AsyncMock→MagicMock for sync scalar_one_or_none

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Mistral Vibe
2026-03-30 21:02:56 +02:00
parent efef818612
commit 68da26588a
12 changed files with 161 additions and 98 deletions

View File

@@ -1,6 +1,9 @@
# ── Security ──────────────────────────────────────────────────────────────────
# Generate with: openssl rand -hex 32
SECRET_KEY=replace_me_with_32_byte_hex
# Shared secret for internal service-to-service calls (nc-watcher → API)
# Generate with: openssl rand -hex 32
INTERNAL_SECRET=replace_me_with_32_byte_hex
# ── Domain ────────────────────────────────────────────────────────────────────
DOMAIN=yourdomain.com

View File

@@ -19,6 +19,7 @@ dependencies = [
"httpx>=0.27",
"redis[hiredis]>=5.0",
"python-multipart>=0.0.9",
"Pillow>=10.0",
]
[project.optional-dependencies]

View File

@@ -7,8 +7,9 @@ class Settings(BaseSettings):
# Security
secret_key: str
internal_secret: str # Shared secret for internal service-to-service calls
jwt_algorithm: str = "HS256"
access_token_expire_minutes: int = 60 * 24 * 7 # 7 days
access_token_expire_minutes: int = 60 # 1 hour
# Database
database_url: str # postgresql+asyncpg://...

View File

@@ -3,7 +3,7 @@
from contextlib import asynccontextmanager
import os
from fastapi import FastAPI
from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
@@ -47,10 +47,19 @@ def create_app() -> FastAPI:
CORSMiddleware,
allow_origins=[f"https://{settings.domain}", "http://localhost:3000"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE"],
allow_headers=["Authorization", "Content-Type", "Accept"],
)
@app.middleware("http")
async def security_headers(request: Request, call_next) -> Response:
response = await call_next(request)
response.headers["X-Frame-Options"] = "DENY"
response.headers["X-Content-Type-Options"] = "nosniff"
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
response.headers["X-XSS-Protection"] = "0" # Disable legacy XSS auditor; rely on CSP
return response
prefix = "/api/v1"
app.include_router(auth_router, prefix=prefix)
app.include_router(bands_router, prefix=prefix)

View File

@@ -79,6 +79,20 @@ class RedisJobQueue:
job.finished_at = datetime.now(timezone.utc)
await self._session.flush()
async def dequeue(self, timeout: int = 5) -> tuple[uuid.UUID, str, dict[str, Any]] | None:
"""Block up to `timeout` seconds for a job. Returns (id, type, payload) or None."""
redis_client = await self._get_redis()
queue_key = get_settings().job_queue_key
result = await redis_client.blpop(queue_key, timeout=timeout)
if result is None:
return None
_, raw_id = result
job_id = uuid.UUID(raw_id)
job = await self._session.get(Job, job_id)
if job is None:
return None
return job_id, job.type, job.payload # type: ignore[return-value]
async def close(self) -> None:
if self._redis:
await self._redis.aclose()

View File

@@ -1,8 +1,11 @@
from fastapi import APIRouter, Depends, HTTPException, status, UploadFile, File
from sqlalchemy.ext.asyncio import AsyncSession
import logging
import os
import uuid
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
from PIL import Image, UnidentifiedImageError
from sqlalchemy.ext.asyncio import AsyncSession
from rehearsalhub.db.engine import get_session
from rehearsalhub.db.models import Member
from rehearsalhub.dependencies import get_current_member
@@ -11,8 +14,13 @@ from rehearsalhub.schemas.auth import LoginRequest, RegisterRequest, TokenRespon
from rehearsalhub.schemas.member import MemberRead, MemberSettingsUpdate
from rehearsalhub.services.auth import AuthService
log = logging.getLogger(__name__)
router = APIRouter(prefix="/auth", tags=["auth"])
_ALLOWED_IMAGE_EXTENSIONS = {"jpg", "jpeg", "png", "gif", "webp"}
_MAX_AVATAR_SIZE = 5 * 1024 * 1024 # 5 MB
@router.post("/register", response_model=MemberRead, status_code=status.HTTP_201_CREATED)
async def register(req: RegisterRequest, session: AsyncSession = Depends(get_session)):
@@ -46,9 +54,6 @@ async def update_settings(
session: AsyncSession = Depends(get_session),
current_member: Member = Depends(get_current_member),
):
print(f"Update settings called for member {current_member.id}")
print(f"Update data: {data.model_dump()}")
repo = MemberRepository(session)
updates: dict = {}
if data.display_name is not None:
@@ -62,14 +67,10 @@ async def update_settings(
if data.avatar_url is not None:
updates["avatar_url"] = data.avatar_url or None
print(f"Updates to apply: {updates}")
if updates:
member = await repo.update(current_member, **updates)
print("Settings updated successfully")
else:
member = current_member
print("No updates to apply")
return MemberRead.from_model(member)
@@ -80,85 +81,68 @@ async def upload_avatar(
current_member: Member = Depends(get_current_member),
):
"""Upload and set user avatar image."""
print(f"Avatar upload called for member {current_member.id}")
print(f"File: {file.filename}, Content-Type: {file.content_type}")
# Validate file type
if not file.content_type.startswith("image/"):
print("Invalid file type")
# Validate extension against whitelist
raw_ext = (file.filename.rsplit(".", 1)[-1] if "." in (file.filename or "") else "").lower()
if raw_ext not in _ALLOWED_IMAGE_EXTENSIONS:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Only image files are allowed (JPG, PNG, GIF, etc.)"
detail="Only JPG, PNG, GIF, and WebP images are allowed.",
)
# Validate file size (5MB limit for upload endpoint)
max_size = 5 * 1024 * 1024 # 5MB
if file.size > max_size:
print(f"File too large: {file.size} bytes (max {max_size})")
# Validate file size
if file.size and file.size > _MAX_AVATAR_SIZE:
raise HTTPException(
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
detail=f"File too large. Maximum size is {max_size / 1024 / 1024}MB. Please resize your image and try again."
detail=f"File too large. Maximum size is {_MAX_AVATAR_SIZE // 1024 // 1024} MB.",
)
# Create uploads directory if it doesn't exist
contents = await file.read()
if len(contents) > _MAX_AVATAR_SIZE:
raise HTTPException(
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
detail=f"File too large. Maximum size is {_MAX_AVATAR_SIZE // 1024 // 1024} MB.",
)
if len(contents) == 0:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Empty file received.",
)
# Validate that the file is actually a valid image using Pillow
try:
import io
img = Image.open(io.BytesIO(contents))
img.verify()
except UnidentifiedImageError:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="File is not a valid image.",
)
except Exception:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Could not process image file.",
)
upload_dir = "uploads/avatars"
os.makedirs(upload_dir, exist_ok=True)
print(f"Using upload directory: {upload_dir}")
# Generate unique filename
file_ext = file.filename.split(".")[-1] if "." in file.filename else "jpg"
filename = f"{uuid.uuid4()}.{file_ext}"
filename = f"{uuid.uuid4()}.{raw_ext}"
file_path = f"{upload_dir}/{filename}"
print(f"Saving file to: {file_path}")
# Save file
try:
contents = await file.read()
print(f"File size: {len(contents)} bytes")
print(f"File content preview: {contents[:50]}...") # First 50 bytes for debugging
# Validate that we actually got content
if len(contents) == 0:
print("Empty file content received")
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Empty file content received"
)
with open(file_path, "wb") as buffer:
buffer.write(contents)
print("File saved successfully")
# Verify file was saved
if not os.path.exists(file_path):
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to verify saved file"
)
file_size = os.path.getsize(file_path)
print(f"Saved file size: {file_size} bytes")
if file_size == 0:
os.remove(file_path)
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Saved file is empty"
)
except Exception as e:
print(f"Failed to save file: {str(e)}")
except Exception:
log.exception("Failed to save avatar for member %s", current_member.id)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to save avatar: {str(e)}"
detail="Failed to save avatar.",
)
# Update member's avatar URL
repo = MemberRepository(session)
avatar_url = f"/api/static/avatars/{filename}"
print(f"Setting avatar URL to: {avatar_url}")
member = await repo.update(current_member, avatar_url=avatar_url)
print("Avatar updated successfully")
return MemberRead.from_model(member)

View File

@@ -3,11 +3,14 @@
import logging
from pathlib import Path
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, Header, HTTPException, status
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from rehearsalhub.config import get_settings
from rehearsalhub.db.engine import get_session
from rehearsalhub.db.models import BandMember, Member
from rehearsalhub.repositories.audio_version import AudioVersionRepository
from rehearsalhub.repositories.band import BandRepository
from rehearsalhub.repositories.rehearsal_session import RehearsalSessionRepository
@@ -15,11 +18,19 @@ from rehearsalhub.repositories.song import SongRepository
from rehearsalhub.schemas.audio_version import AudioVersionCreate
from rehearsalhub.services.session import extract_session_folder, parse_rehearsal_date
from rehearsalhub.services.song import SongService
from rehearsalhub.storage.nextcloud import NextcloudClient
log = logging.getLogger(__name__)
router = APIRouter(prefix="/internal", tags=["internal"])
async def _verify_internal_secret(x_internal_token: str | None = Header(None)) -> None:
"""Verify the shared secret sent by internal services."""
settings = get_settings()
if x_internal_token != settings.internal_secret:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden")
AUDIO_EXTENSIONS = {".mp3", ".wav", ".flac", ".ogg", ".m4a", ".aac", ".opus"}
@@ -32,6 +43,7 @@ class NcUploadEvent(BaseModel):
async def nc_upload(
event: NcUploadEvent,
session: AsyncSession = Depends(get_session),
_: None = Depends(_verify_internal_secret),
):
"""
Called by nc-watcher when a new audio file is detected in Nextcloud.
@@ -105,13 +117,11 @@ async def nc_upload(
nc_folder_path=nc_folder,
created_by=None,
)
log.info("nc-upload: created song '%s' for band '%s'", title, band_slug)
log.info("nc-upload: created song '%s' for band '%s'", title, band.slug)
elif rehearsal_session_id and song.session_id is None:
song = await song_repo.update(song, session_id=rehearsal_session_id)
# Use first member of the band as uploader (best-effort for watcher uploads)
from sqlalchemy import select
from rehearsalhub.db.models import BandMember
result = await session.execute(
select(BandMember.member_id).where(BandMember.band_id == band.id).limit(1)
)
@@ -121,7 +131,7 @@ async def nc_upload(
storage = None
if uploader_id:
uploader_result = await session.execute(
select(Member).where(Member.id == uploader_id).limit(1)
select(Member).where(Member.id == uploader_id).limit(1) # type: ignore[arg-type]
)
uploader = uploader_result.scalar_one_or_none()
storage = NextcloudClient.for_member(uploader) if uploader else None

View File

@@ -189,26 +189,26 @@ async def get_waveform(
)
try:
data = await _download_with_retry(storage, version.waveform_url)
except httpx.ConnectError as e:
except httpx.ConnectError:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"Failed to connect to storage: {str(e)}"
detail="Storage service unavailable."
)
except httpx.HTTPStatusError as e:
if e.response.status_code == 404:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Waveform file not found in storage"
detail="Waveform file not found in storage."
)
else:
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail=f"Storage error: {str(e)}"
detail="Storage returned an error."
)
except Exception as e:
except Exception:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to fetch waveform: {str(e)}"
detail="Failed to fetch waveform."
)
import json
@@ -239,26 +239,26 @@ async def stream_version(
)
try:
data = await _download_with_retry(storage, file_path)
except httpx.ConnectError as e:
except httpx.ConnectError:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=f"Failed to connect to storage: {str(e)}"
detail="Storage service unavailable."
)
except httpx.HTTPStatusError as e:
if e.response.status_code == 404:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="File not found in storage"
detail="File not found in storage."
)
else:
raise HTTPException(
status_code=status.HTTP_502_BAD_GATEWAY,
detail=f"Storage error: {str(e)}"
detail="Storage returned an error."
)
except Exception as e:
except Exception:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to stream file: {str(e)}"
detail="Failed to stream file."
)
content_type = _AUDIO_CONTENT_TYPES.get(Path(file_path).suffix.lower(), "application/octet-stream")

View File

@@ -2,15 +2,35 @@
import uuid
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
from rehearsalhub.repositories.member import MemberRepository
from rehearsalhub.db.engine import get_session
from rehearsalhub.services.auth import decode_token
from rehearsalhub.ws import manager
router = APIRouter(tags=["websocket"])
@router.websocket("/ws/versions/{version_id}")
async def version_ws(version_id: uuid.UUID, websocket: WebSocket):
async def version_ws(
version_id: uuid.UUID,
websocket: WebSocket,
token: str = Query(...),
):
"""WebSocket endpoint. Requires a valid JWT passed as ?token=<jwt>."""
# Validate token before accepting the connection
async for session in get_session():
try:
payload = decode_token(token)
member_id = uuid.UUID(payload["sub"])
member = await MemberRepository(session).get_by_id(member_id)
if member is None:
raise ValueError("member not found")
except Exception:
await websocket.close(code=4001)
return
await manager.connect(version_id, websocket)
try:
while True:

View File

@@ -0,0 +1,18 @@
"""Unit test fixtures — sets required env vars so Settings() loads without a .env file."""
import pytest
@pytest.fixture(autouse=True)
def patch_settings(monkeypatch):
"""Provide the minimum env vars that Settings() requires for unit tests."""
monkeypatch.setenv("SECRET_KEY", "a" * 64)
monkeypatch.setenv("INTERNAL_SECRET", "b" * 64)
monkeypatch.setenv("DATABASE_URL", "postgresql+asyncpg://test:test@localhost/test")
# Clear the lru_cache so each test gets a fresh Settings instance with the
# monkeypatched env vars rather than the cached production instance.
from rehearsalhub.config import get_settings
get_settings.cache_clear()
yield
get_settings.cache_clear()

View File

@@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from rehearsalhub.queue.redis_queue import RedisJobQueue
from rehearsalhub.queue.redis_queue import RedisJobQueue, flush_pending_pushes
@pytest.mark.asyncio
@@ -34,6 +34,9 @@ async def test_enqueue_creates_job_and_pushes_to_redis(mock_session):
job_id = await queue.enqueue("transcode", {"version_id": "abc"})
mock_session.add.assert_called_once()
# The Redis push is deferred; it fires when flush_pending_pushes is called after commit.
mock_redis.rpush.assert_not_called()
await flush_pending_pushes(mock_session)
mock_redis.rpush.assert_called_once()

View File

@@ -56,7 +56,7 @@ async def test_band_is_member_calls_get_member_role(mock_session):
band_id = uuid.uuid4()
member_id = uuid.uuid4()
result_mock = AsyncMock()
result_mock = MagicMock()
result_mock.scalar_one_or_none.return_value = "admin"
mock_session.execute.return_value = result_mock
@@ -70,7 +70,7 @@ async def test_band_is_member_false_when_no_role(mock_session):
band_id = uuid.uuid4()
member_id = uuid.uuid4()
result_mock = AsyncMock()
result_mock = MagicMock()
result_mock.scalar_one_or_none.return_value = None
mock_session.execute.return_value = result_mock