diff --git a/api/pyproject.toml b/api/pyproject.toml
index 59d2f6c..b583cfb 100644
--- a/api/pyproject.toml
+++ b/api/pyproject.toml
@@ -20,6 +20,7 @@ dependencies = [
"redis[hiredis]>=5.0",
"python-multipart>=0.0.9",
"Pillow>=10.0",
+ "slowapi>=0.1.9",
]
[project.optional-dependencies]
diff --git a/api/src/rehearsalhub/dependencies.py b/api/src/rehearsalhub/dependencies.py
index 892b07c..635a7dc 100644
--- a/api/src/rehearsalhub/dependencies.py
+++ b/api/src/rehearsalhub/dependencies.py
@@ -4,7 +4,7 @@ from __future__ import annotations
import uuid
-from fastapi import Depends, HTTPException, status
+from fastapi import Depends, HTTPException, Request, status
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.ext.asyncio import AsyncSession
@@ -13,18 +13,25 @@ from rehearsalhub.db.models import Member
from rehearsalhub.services.auth import decode_token
from rehearsalhub.repositories.member import MemberRepository
-oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
+# auto_error=False so we can fall back to cookie auth without a 401 from the scheme itself
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login", auto_error=False)
async def get_current_member(
- token: str = Depends(oauth2_scheme),
+ request: Request,
+ bearer_token: str | None = Depends(oauth2_scheme),
session: AsyncSession = Depends(get_session),
) -> Member:
+ # Prefer Authorization: Bearer header; fall back to httpOnly cookie
+ token = bearer_token or request.cookies.get("rh_token")
+
credentials_exc = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired token",
headers={"WWW-Authenticate": "Bearer"},
)
+ if not token:
+ raise credentials_exc
try:
payload = decode_token(token)
member_id_str: str | None = payload.get("sub")
diff --git a/api/src/rehearsalhub/main.py b/api/src/rehearsalhub/main.py
index c9f630e..6c4def4 100644
--- a/api/src/rehearsalhub/main.py
+++ b/api/src/rehearsalhub/main.py
@@ -6,6 +6,9 @@ import os
from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
+from slowapi import Limiter, _rate_limit_exceeded_handler
+from slowapi.errors import RateLimitExceeded
+from slowapi.util import get_remote_address
from rehearsalhub.config import get_settings
from rehearsalhub.routers import (
@@ -20,6 +23,8 @@ from rehearsalhub.routers import (
ws_router,
)
+limiter = Limiter(key_func=get_remote_address)
+
@asynccontextmanager
async def lifespan(app: FastAPI):
@@ -43,6 +48,9 @@ def create_app() -> FastAPI:
lifespan=lifespan,
)
+ app.state.limiter = limiter
+ app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
+
app.add_middleware(
CORSMiddleware,
allow_origins=[f"https://{settings.domain}", "http://localhost:3000"],
diff --git a/api/src/rehearsalhub/routers/auth.py b/api/src/rehearsalhub/routers/auth.py
index 90b504c..6f30a16 100644
--- a/api/src/rehearsalhub/routers/auth.py
+++ b/api/src/rehearsalhub/routers/auth.py
@@ -2,10 +2,13 @@ import logging
import os
import uuid
-from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
+from fastapi import APIRouter, Depends, File, HTTPException, Request, Response, UploadFile, status
from PIL import Image, UnidentifiedImageError
+from slowapi import Limiter
+from slowapi.util import get_remote_address
from sqlalchemy.ext.asyncio import AsyncSession
+from rehearsalhub.config import get_settings
from rehearsalhub.db.engine import get_session
from rehearsalhub.db.models import Member
from rehearsalhub.dependencies import get_current_member
@@ -17,13 +20,15 @@ from rehearsalhub.services.auth import AuthService
log = logging.getLogger(__name__)
router = APIRouter(prefix="/auth", tags=["auth"])
+limiter = Limiter(key_func=get_remote_address)
_ALLOWED_IMAGE_EXTENSIONS = {"jpg", "jpeg", "png", "gif", "webp"}
_MAX_AVATAR_SIZE = 5 * 1024 * 1024 # 5 MB
@router.post("/register", response_model=MemberRead, status_code=status.HTTP_201_CREATED)
-async def register(req: RegisterRequest, session: AsyncSession = Depends(get_session)):
+@limiter.limit("5/minute")
+async def register(request: Request, req: RegisterRequest, session: AsyncSession = Depends(get_session)):
svc = AuthService(session)
try:
member = await svc.register(req)
@@ -33,16 +38,38 @@ async def register(req: RegisterRequest, session: AsyncSession = Depends(get_ses
@router.post("/login", response_model=TokenResponse)
-async def login(req: LoginRequest, session: AsyncSession = Depends(get_session)):
+@limiter.limit("10/minute")
+async def login(
+ request: Request,
+ req: LoginRequest,
+ response: Response,
+ session: AsyncSession = Depends(get_session),
+):
svc = AuthService(session)
token = await svc.login(req.email, req.password)
if token is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials"
)
+ settings = get_settings()
+ response.set_cookie(
+ key="rh_token",
+ value=token.access_token,
+ httponly=True,
+ secure=not settings.debug,
+ samesite="lax",
+ max_age=settings.access_token_expire_minutes * 60,
+ path="/",
+ )
return token
+@router.post("/logout", status_code=status.HTTP_204_NO_CONTENT)
+async def logout(response: Response):
+ response.delete_cookie(key="rh_token", path="/")
+ return None
+
+
@router.get("/me", response_model=MemberRead)
async def get_me(current_member: Member = Depends(get_current_member)):
return MemberRead.from_model(current_member)
diff --git a/api/src/rehearsalhub/routers/songs.py b/api/src/rehearsalhub/routers/songs.py
index 219ff00..cb0a9ee 100644
--- a/api/src/rehearsalhub/routers/songs.py
+++ b/api/src/rehearsalhub/routers/songs.py
@@ -180,9 +180,9 @@ async def scan_nextcloud_stream(
yield json.dumps(event) + "\n"
if event.get("type") in ("song", "session"):
await db.commit()
- except Exception as exc:
+ except Exception:
log.exception("SSE scan error for band %s", band_id)
- yield json.dumps({"type": "error", "message": str(exc)}) + "\n"
+ yield json.dumps({"type": "error", "message": "Scan failed due to an internal error."}) + "\n"
finally:
await db.commit()
diff --git a/api/src/rehearsalhub/routers/ws.py b/api/src/rehearsalhub/routers/ws.py
index 8759b37..b328d60 100644
--- a/api/src/rehearsalhub/routers/ws.py
+++ b/api/src/rehearsalhub/routers/ws.py
@@ -16,13 +16,19 @@ router = APIRouter(tags=["websocket"])
async def version_ws(
version_id: uuid.UUID,
websocket: WebSocket,
- token: str = Query(...),
+ token: str | None = Query(None),
):
- """WebSocket endpoint. Requires a valid JWT passed as ?token=