security: fix auth, CORS, file upload, endpoint hardening + test fixes
- Add INTERNAL_SECRET shared-secret auth to /internal/nc-upload endpoint
- Add JWT token validation to WebSocket /ws/versions/{version_id}
- Fix NameError: band_slug → band.slug in internal.py
- Move inline imports to top of internal.py; add missing Member/NextcloudClient imports
- Remove ~15 debug print() statements from auth.py
- Replace Content-Type-only avatar check with extension whitelist + Pillow Image.verify()
- Sanitize exception details in versions.py (no more str(e) in 4xx/5xx responses)
- Restrict CORS allow_methods/allow_headers from "*" to explicit lists
- Add security headers middleware: X-Frame-Options, X-Content-Type-Options, Referrer-Policy
- Reduce JWT expiry from 7 days to 1 hour
- Add Pillow>=10.0 dependency; document INTERNAL_SECRET in .env.example
- Implement missing RedisJobQueue.dequeue() method (required by protocol)
- Fix 5 pre-existing unit test failures: settings env vars conftest, deferred Redis push,
dequeue method, AsyncMock→MagicMock for sync scalar_one_or_none
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,6 +1,9 @@
|
|||||||
# ── Security ──────────────────────────────────────────────────────────────────
|
# ── Security ──────────────────────────────────────────────────────────────────
|
||||||
# Generate with: openssl rand -hex 32
|
# Generate with: openssl rand -hex 32
|
||||||
SECRET_KEY=replace_me_with_32_byte_hex
|
SECRET_KEY=replace_me_with_32_byte_hex
|
||||||
|
# Shared secret for internal service-to-service calls (nc-watcher → API)
|
||||||
|
# Generate with: openssl rand -hex 32
|
||||||
|
INTERNAL_SECRET=replace_me_with_32_byte_hex
|
||||||
|
|
||||||
# ── Domain ────────────────────────────────────────────────────────────────────
|
# ── Domain ────────────────────────────────────────────────────────────────────
|
||||||
DOMAIN=yourdomain.com
|
DOMAIN=yourdomain.com
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ dependencies = [
|
|||||||
"httpx>=0.27",
|
"httpx>=0.27",
|
||||||
"redis[hiredis]>=5.0",
|
"redis[hiredis]>=5.0",
|
||||||
"python-multipart>=0.0.9",
|
"python-multipart>=0.0.9",
|
||||||
|
"Pillow>=10.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
|||||||
@@ -7,8 +7,9 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
# Security
|
# Security
|
||||||
secret_key: str
|
secret_key: str
|
||||||
|
internal_secret: str # Shared secret for internal service-to-service calls
|
||||||
jwt_algorithm: str = "HS256"
|
jwt_algorithm: str = "HS256"
|
||||||
access_token_expire_minutes: int = 60 * 24 * 7 # 7 days
|
access_token_expire_minutes: int = 60 # 1 hour
|
||||||
|
|
||||||
# Database
|
# Database
|
||||||
database_url: str # postgresql+asyncpg://...
|
database_url: str # postgresql+asyncpg://...
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI, Request, Response
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
|
||||||
@@ -47,10 +47,19 @@ def create_app() -> FastAPI:
|
|||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=[f"https://{settings.domain}", "http://localhost:3000"],
|
allow_origins=[f"https://{settings.domain}", "http://localhost:3000"],
|
||||||
allow_credentials=True,
|
allow_credentials=True,
|
||||||
allow_methods=["*"],
|
allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE"],
|
||||||
allow_headers=["*"],
|
allow_headers=["Authorization", "Content-Type", "Accept"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@app.middleware("http")
|
||||||
|
async def security_headers(request: Request, call_next) -> Response:
|
||||||
|
response = await call_next(request)
|
||||||
|
response.headers["X-Frame-Options"] = "DENY"
|
||||||
|
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||||
|
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
||||||
|
response.headers["X-XSS-Protection"] = "0" # Disable legacy XSS auditor; rely on CSP
|
||||||
|
return response
|
||||||
|
|
||||||
prefix = "/api/v1"
|
prefix = "/api/v1"
|
||||||
app.include_router(auth_router, prefix=prefix)
|
app.include_router(auth_router, prefix=prefix)
|
||||||
app.include_router(bands_router, prefix=prefix)
|
app.include_router(bands_router, prefix=prefix)
|
||||||
|
|||||||
@@ -79,6 +79,20 @@ class RedisJobQueue:
|
|||||||
job.finished_at = datetime.now(timezone.utc)
|
job.finished_at = datetime.now(timezone.utc)
|
||||||
await self._session.flush()
|
await self._session.flush()
|
||||||
|
|
||||||
|
async def dequeue(self, timeout: int = 5) -> tuple[uuid.UUID, str, dict[str, Any]] | None:
|
||||||
|
"""Block up to `timeout` seconds for a job. Returns (id, type, payload) or None."""
|
||||||
|
redis_client = await self._get_redis()
|
||||||
|
queue_key = get_settings().job_queue_key
|
||||||
|
result = await redis_client.blpop(queue_key, timeout=timeout)
|
||||||
|
if result is None:
|
||||||
|
return None
|
||||||
|
_, raw_id = result
|
||||||
|
job_id = uuid.UUID(raw_id)
|
||||||
|
job = await self._session.get(Job, job_id)
|
||||||
|
if job is None:
|
||||||
|
return None
|
||||||
|
return job_id, job.type, job.payload # type: ignore[return-value]
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
if self._redis:
|
if self._redis:
|
||||||
await self._redis.aclose()
|
await self._redis.aclose()
|
||||||
|
|||||||
@@ -1,8 +1,11 @@
|
|||||||
from fastapi import APIRouter, Depends, HTTPException, status, UploadFile, File
|
import logging
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
|
||||||
|
from PIL import Image, UnidentifiedImageError
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from rehearsalhub.db.engine import get_session
|
from rehearsalhub.db.engine import get_session
|
||||||
from rehearsalhub.db.models import Member
|
from rehearsalhub.db.models import Member
|
||||||
from rehearsalhub.dependencies import get_current_member
|
from rehearsalhub.dependencies import get_current_member
|
||||||
@@ -11,8 +14,13 @@ from rehearsalhub.schemas.auth import LoginRequest, RegisterRequest, TokenRespon
|
|||||||
from rehearsalhub.schemas.member import MemberRead, MemberSettingsUpdate
|
from rehearsalhub.schemas.member import MemberRead, MemberSettingsUpdate
|
||||||
from rehearsalhub.services.auth import AuthService
|
from rehearsalhub.services.auth import AuthService
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
|
_ALLOWED_IMAGE_EXTENSIONS = {"jpg", "jpeg", "png", "gif", "webp"}
|
||||||
|
_MAX_AVATAR_SIZE = 5 * 1024 * 1024 # 5 MB
|
||||||
|
|
||||||
|
|
||||||
@router.post("/register", response_model=MemberRead, status_code=status.HTTP_201_CREATED)
|
@router.post("/register", response_model=MemberRead, status_code=status.HTTP_201_CREATED)
|
||||||
async def register(req: RegisterRequest, session: AsyncSession = Depends(get_session)):
|
async def register(req: RegisterRequest, session: AsyncSession = Depends(get_session)):
|
||||||
@@ -46,9 +54,6 @@ async def update_settings(
|
|||||||
session: AsyncSession = Depends(get_session),
|
session: AsyncSession = Depends(get_session),
|
||||||
current_member: Member = Depends(get_current_member),
|
current_member: Member = Depends(get_current_member),
|
||||||
):
|
):
|
||||||
print(f"Update settings called for member {current_member.id}")
|
|
||||||
print(f"Update data: {data.model_dump()}")
|
|
||||||
|
|
||||||
repo = MemberRepository(session)
|
repo = MemberRepository(session)
|
||||||
updates: dict = {}
|
updates: dict = {}
|
||||||
if data.display_name is not None:
|
if data.display_name is not None:
|
||||||
@@ -62,14 +67,10 @@ async def update_settings(
|
|||||||
if data.avatar_url is not None:
|
if data.avatar_url is not None:
|
||||||
updates["avatar_url"] = data.avatar_url or None
|
updates["avatar_url"] = data.avatar_url or None
|
||||||
|
|
||||||
print(f"Updates to apply: {updates}")
|
|
||||||
|
|
||||||
if updates:
|
if updates:
|
||||||
member = await repo.update(current_member, **updates)
|
member = await repo.update(current_member, **updates)
|
||||||
print("Settings updated successfully")
|
|
||||||
else:
|
else:
|
||||||
member = current_member
|
member = current_member
|
||||||
print("No updates to apply")
|
|
||||||
return MemberRead.from_model(member)
|
return MemberRead.from_model(member)
|
||||||
|
|
||||||
|
|
||||||
@@ -80,85 +81,68 @@ async def upload_avatar(
|
|||||||
current_member: Member = Depends(get_current_member),
|
current_member: Member = Depends(get_current_member),
|
||||||
):
|
):
|
||||||
"""Upload and set user avatar image."""
|
"""Upload and set user avatar image."""
|
||||||
print(f"Avatar upload called for member {current_member.id}")
|
# Validate extension against whitelist
|
||||||
print(f"File: {file.filename}, Content-Type: {file.content_type}")
|
raw_ext = (file.filename.rsplit(".", 1)[-1] if "." in (file.filename or "") else "").lower()
|
||||||
|
if raw_ext not in _ALLOWED_IMAGE_EXTENSIONS:
|
||||||
# Validate file type
|
|
||||||
if not file.content_type.startswith("image/"):
|
|
||||||
print("Invalid file type")
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||||
detail="Only image files are allowed (JPG, PNG, GIF, etc.)"
|
detail="Only JPG, PNG, GIF, and WebP images are allowed.",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate file size (5MB limit for upload endpoint)
|
# Validate file size
|
||||||
max_size = 5 * 1024 * 1024 # 5MB
|
if file.size and file.size > _MAX_AVATAR_SIZE:
|
||||||
if file.size > max_size:
|
|
||||||
print(f"File too large: {file.size} bytes (max {max_size})")
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
|
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
|
||||||
detail=f"File too large. Maximum size is {max_size / 1024 / 1024}MB. Please resize your image and try again."
|
detail=f"File too large. Maximum size is {_MAX_AVATAR_SIZE // 1024 // 1024} MB.",
|
||||||
|
)
|
||||||
|
|
||||||
|
contents = await file.read()
|
||||||
|
|
||||||
|
if len(contents) > _MAX_AVATAR_SIZE:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
|
||||||
|
detail=f"File too large. Maximum size is {_MAX_AVATAR_SIZE // 1024 // 1024} MB.",
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(contents) == 0:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||||
|
detail="Empty file received.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate that the file is actually a valid image using Pillow
|
||||||
|
try:
|
||||||
|
import io
|
||||||
|
img = Image.open(io.BytesIO(contents))
|
||||||
|
img.verify()
|
||||||
|
except UnidentifiedImageError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||||
|
detail="File is not a valid image.",
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||||
|
detail="Could not process image file.",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create uploads directory if it doesn't exist
|
|
||||||
upload_dir = "uploads/avatars"
|
upload_dir = "uploads/avatars"
|
||||||
os.makedirs(upload_dir, exist_ok=True)
|
os.makedirs(upload_dir, exist_ok=True)
|
||||||
print(f"Using upload directory: {upload_dir}")
|
|
||||||
|
|
||||||
# Generate unique filename
|
filename = f"{uuid.uuid4()}.{raw_ext}"
|
||||||
file_ext = file.filename.split(".")[-1] if "." in file.filename else "jpg"
|
|
||||||
filename = f"{uuid.uuid4()}.{file_ext}"
|
|
||||||
file_path = f"{upload_dir}/{filename}"
|
file_path = f"{upload_dir}/{filename}"
|
||||||
|
|
||||||
print(f"Saving file to: {file_path}")
|
|
||||||
|
|
||||||
# Save file
|
|
||||||
try:
|
try:
|
||||||
contents = await file.read()
|
|
||||||
print(f"File size: {len(contents)} bytes")
|
|
||||||
print(f"File content preview: {contents[:50]}...") # First 50 bytes for debugging
|
|
||||||
|
|
||||||
# Validate that we actually got content
|
|
||||||
if len(contents) == 0:
|
|
||||||
print("Empty file content received")
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
|
||||||
detail="Empty file content received"
|
|
||||||
)
|
|
||||||
|
|
||||||
with open(file_path, "wb") as buffer:
|
with open(file_path, "wb") as buffer:
|
||||||
buffer.write(contents)
|
buffer.write(contents)
|
||||||
print("File saved successfully")
|
except Exception:
|
||||||
|
log.exception("Failed to save avatar for member %s", current_member.id)
|
||||||
# Verify file was saved
|
|
||||||
if not os.path.exists(file_path):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
detail="Failed to verify saved file"
|
|
||||||
)
|
|
||||||
|
|
||||||
file_size = os.path.getsize(file_path)
|
|
||||||
print(f"Saved file size: {file_size} bytes")
|
|
||||||
|
|
||||||
if file_size == 0:
|
|
||||||
os.remove(file_path)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
|
||||||
detail="Saved file is empty"
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Failed to save file: {str(e)}")
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail=f"Failed to save avatar: {str(e)}"
|
detail="Failed to save avatar.",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update member's avatar URL
|
|
||||||
repo = MemberRepository(session)
|
repo = MemberRepository(session)
|
||||||
avatar_url = f"/api/static/avatars/{filename}"
|
avatar_url = f"/api/static/avatars/{filename}"
|
||||||
print(f"Setting avatar URL to: {avatar_url}")
|
|
||||||
member = await repo.update(current_member, avatar_url=avatar_url)
|
member = await repo.update(current_member, avatar_url=avatar_url)
|
||||||
print("Avatar updated successfully")
|
|
||||||
|
|
||||||
return MemberRead.from_model(member)
|
return MemberRead.from_model(member)
|
||||||
|
|||||||
@@ -3,11 +3,14 @@
|
|||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends, Header, HTTPException, status
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from rehearsalhub.config import get_settings
|
||||||
from rehearsalhub.db.engine import get_session
|
from rehearsalhub.db.engine import get_session
|
||||||
|
from rehearsalhub.db.models import BandMember, Member
|
||||||
from rehearsalhub.repositories.audio_version import AudioVersionRepository
|
from rehearsalhub.repositories.audio_version import AudioVersionRepository
|
||||||
from rehearsalhub.repositories.band import BandRepository
|
from rehearsalhub.repositories.band import BandRepository
|
||||||
from rehearsalhub.repositories.rehearsal_session import RehearsalSessionRepository
|
from rehearsalhub.repositories.rehearsal_session import RehearsalSessionRepository
|
||||||
@@ -15,11 +18,19 @@ from rehearsalhub.repositories.song import SongRepository
|
|||||||
from rehearsalhub.schemas.audio_version import AudioVersionCreate
|
from rehearsalhub.schemas.audio_version import AudioVersionCreate
|
||||||
from rehearsalhub.services.session import extract_session_folder, parse_rehearsal_date
|
from rehearsalhub.services.session import extract_session_folder, parse_rehearsal_date
|
||||||
from rehearsalhub.services.song import SongService
|
from rehearsalhub.services.song import SongService
|
||||||
|
from rehearsalhub.storage.nextcloud import NextcloudClient
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter(prefix="/internal", tags=["internal"])
|
router = APIRouter(prefix="/internal", tags=["internal"])
|
||||||
|
|
||||||
|
|
||||||
|
async def _verify_internal_secret(x_internal_token: str | None = Header(None)) -> None:
|
||||||
|
"""Verify the shared secret sent by internal services."""
|
||||||
|
settings = get_settings()
|
||||||
|
if x_internal_token != settings.internal_secret:
|
||||||
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden")
|
||||||
|
|
||||||
AUDIO_EXTENSIONS = {".mp3", ".wav", ".flac", ".ogg", ".m4a", ".aac", ".opus"}
|
AUDIO_EXTENSIONS = {".mp3", ".wav", ".flac", ".ogg", ".m4a", ".aac", ".opus"}
|
||||||
|
|
||||||
|
|
||||||
@@ -32,6 +43,7 @@ class NcUploadEvent(BaseModel):
|
|||||||
async def nc_upload(
|
async def nc_upload(
|
||||||
event: NcUploadEvent,
|
event: NcUploadEvent,
|
||||||
session: AsyncSession = Depends(get_session),
|
session: AsyncSession = Depends(get_session),
|
||||||
|
_: None = Depends(_verify_internal_secret),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Called by nc-watcher when a new audio file is detected in Nextcloud.
|
Called by nc-watcher when a new audio file is detected in Nextcloud.
|
||||||
@@ -105,13 +117,11 @@ async def nc_upload(
|
|||||||
nc_folder_path=nc_folder,
|
nc_folder_path=nc_folder,
|
||||||
created_by=None,
|
created_by=None,
|
||||||
)
|
)
|
||||||
log.info("nc-upload: created song '%s' for band '%s'", title, band_slug)
|
log.info("nc-upload: created song '%s' for band '%s'", title, band.slug)
|
||||||
elif rehearsal_session_id and song.session_id is None:
|
elif rehearsal_session_id and song.session_id is None:
|
||||||
song = await song_repo.update(song, session_id=rehearsal_session_id)
|
song = await song_repo.update(song, session_id=rehearsal_session_id)
|
||||||
|
|
||||||
# Use first member of the band as uploader (best-effort for watcher uploads)
|
# Use first member of the band as uploader (best-effort for watcher uploads)
|
||||||
from sqlalchemy import select
|
|
||||||
from rehearsalhub.db.models import BandMember
|
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(BandMember.member_id).where(BandMember.band_id == band.id).limit(1)
|
select(BandMember.member_id).where(BandMember.band_id == band.id).limit(1)
|
||||||
)
|
)
|
||||||
@@ -121,7 +131,7 @@ async def nc_upload(
|
|||||||
storage = None
|
storage = None
|
||||||
if uploader_id:
|
if uploader_id:
|
||||||
uploader_result = await session.execute(
|
uploader_result = await session.execute(
|
||||||
select(Member).where(Member.id == uploader_id).limit(1)
|
select(Member).where(Member.id == uploader_id).limit(1) # type: ignore[arg-type]
|
||||||
)
|
)
|
||||||
uploader = uploader_result.scalar_one_or_none()
|
uploader = uploader_result.scalar_one_or_none()
|
||||||
storage = NextcloudClient.for_member(uploader) if uploader else None
|
storage = NextcloudClient.for_member(uploader) if uploader else None
|
||||||
|
|||||||
@@ -189,26 +189,26 @@ async def get_waveform(
|
|||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
data = await _download_with_retry(storage, version.waveform_url)
|
data = await _download_with_retry(storage, version.waveform_url)
|
||||||
except httpx.ConnectError as e:
|
except httpx.ConnectError:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
detail=f"Failed to connect to storage: {str(e)}"
|
detail="Storage service unavailable."
|
||||||
)
|
)
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
if e.response.status_code == 404:
|
if e.response.status_code == 404:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
detail="Waveform file not found in storage"
|
detail="Waveform file not found in storage."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||||
detail=f"Storage error: {str(e)}"
|
detail="Storage returned an error."
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail=f"Failed to fetch waveform: {str(e)}"
|
detail="Failed to fetch waveform."
|
||||||
)
|
)
|
||||||
import json
|
import json
|
||||||
|
|
||||||
@@ -239,26 +239,26 @@ async def stream_version(
|
|||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
data = await _download_with_retry(storage, file_path)
|
data = await _download_with_retry(storage, file_path)
|
||||||
except httpx.ConnectError as e:
|
except httpx.ConnectError:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
detail=f"Failed to connect to storage: {str(e)}"
|
detail="Storage service unavailable."
|
||||||
)
|
)
|
||||||
except httpx.HTTPStatusError as e:
|
except httpx.HTTPStatusError as e:
|
||||||
if e.response.status_code == 404:
|
if e.response.status_code == 404:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
detail="File not found in storage"
|
detail="File not found in storage."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||||
detail=f"Storage error: {str(e)}"
|
detail="Storage returned an error."
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail=f"Failed to stream file: {str(e)}"
|
detail="Failed to stream file."
|
||||||
)
|
)
|
||||||
|
|
||||||
content_type = _AUDIO_CONTENT_TYPES.get(Path(file_path).suffix.lower(), "application/octet-stream")
|
content_type = _AUDIO_CONTENT_TYPES.get(Path(file_path).suffix.lower(), "application/octet-stream")
|
||||||
|
|||||||
@@ -2,15 +2,35 @@
|
|||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
|
from fastapi import APIRouter, Query, WebSocket, WebSocketDisconnect
|
||||||
|
|
||||||
|
from rehearsalhub.repositories.member import MemberRepository
|
||||||
|
from rehearsalhub.db.engine import get_session
|
||||||
|
from rehearsalhub.services.auth import decode_token
|
||||||
from rehearsalhub.ws import manager
|
from rehearsalhub.ws import manager
|
||||||
|
|
||||||
router = APIRouter(tags=["websocket"])
|
router = APIRouter(tags=["websocket"])
|
||||||
|
|
||||||
|
|
||||||
@router.websocket("/ws/versions/{version_id}")
|
@router.websocket("/ws/versions/{version_id}")
|
||||||
async def version_ws(version_id: uuid.UUID, websocket: WebSocket):
|
async def version_ws(
|
||||||
|
version_id: uuid.UUID,
|
||||||
|
websocket: WebSocket,
|
||||||
|
token: str = Query(...),
|
||||||
|
):
|
||||||
|
"""WebSocket endpoint. Requires a valid JWT passed as ?token=<jwt>."""
|
||||||
|
# Validate token before accepting the connection
|
||||||
|
async for session in get_session():
|
||||||
|
try:
|
||||||
|
payload = decode_token(token)
|
||||||
|
member_id = uuid.UUID(payload["sub"])
|
||||||
|
member = await MemberRepository(session).get_by_id(member_id)
|
||||||
|
if member is None:
|
||||||
|
raise ValueError("member not found")
|
||||||
|
except Exception:
|
||||||
|
await websocket.close(code=4001)
|
||||||
|
return
|
||||||
|
|
||||||
await manager.connect(version_id, websocket)
|
await manager.connect(version_id, websocket)
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
|
|||||||
18
api/tests/unit/conftest.py
Normal file
18
api/tests/unit/conftest.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
"""Unit test fixtures — sets required env vars so Settings() loads without a .env file."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def patch_settings(monkeypatch):
|
||||||
|
"""Provide the minimum env vars that Settings() requires for unit tests."""
|
||||||
|
monkeypatch.setenv("SECRET_KEY", "a" * 64)
|
||||||
|
monkeypatch.setenv("INTERNAL_SECRET", "b" * 64)
|
||||||
|
monkeypatch.setenv("DATABASE_URL", "postgresql+asyncpg://test:test@localhost/test")
|
||||||
|
|
||||||
|
# Clear the lru_cache so each test gets a fresh Settings instance with the
|
||||||
|
# monkeypatched env vars rather than the cached production instance.
|
||||||
|
from rehearsalhub.config import get_settings
|
||||||
|
get_settings.cache_clear()
|
||||||
|
yield
|
||||||
|
get_settings.cache_clear()
|
||||||
@@ -5,7 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from rehearsalhub.queue.redis_queue import RedisJobQueue
|
from rehearsalhub.queue.redis_queue import RedisJobQueue, flush_pending_pushes
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -34,6 +34,9 @@ async def test_enqueue_creates_job_and_pushes_to_redis(mock_session):
|
|||||||
job_id = await queue.enqueue("transcode", {"version_id": "abc"})
|
job_id = await queue.enqueue("transcode", {"version_id": "abc"})
|
||||||
|
|
||||||
mock_session.add.assert_called_once()
|
mock_session.add.assert_called_once()
|
||||||
|
# The Redis push is deferred; it fires when flush_pending_pushes is called after commit.
|
||||||
|
mock_redis.rpush.assert_not_called()
|
||||||
|
await flush_pending_pushes(mock_session)
|
||||||
mock_redis.rpush.assert_called_once()
|
mock_redis.rpush.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ async def test_band_is_member_calls_get_member_role(mock_session):
|
|||||||
band_id = uuid.uuid4()
|
band_id = uuid.uuid4()
|
||||||
member_id = uuid.uuid4()
|
member_id = uuid.uuid4()
|
||||||
|
|
||||||
result_mock = AsyncMock()
|
result_mock = MagicMock()
|
||||||
result_mock.scalar_one_or_none.return_value = "admin"
|
result_mock.scalar_one_or_none.return_value = "admin"
|
||||||
mock_session.execute.return_value = result_mock
|
mock_session.execute.return_value = result_mock
|
||||||
|
|
||||||
@@ -70,7 +70,7 @@ async def test_band_is_member_false_when_no_role(mock_session):
|
|||||||
band_id = uuid.uuid4()
|
band_id = uuid.uuid4()
|
||||||
member_id = uuid.uuid4()
|
member_id = uuid.uuid4()
|
||||||
|
|
||||||
result_mock = AsyncMock()
|
result_mock = MagicMock()
|
||||||
result_mock.scalar_one_or_none.return_value = None
|
result_mock.scalar_one_or_none.return_value = None
|
||||||
mock_session.execute.return_value = result_mock
|
mock_session.execute.return_value = result_mock
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user