Files
rehearshalhub/api/src/rehearsalhub/main.py
2026-04-08 15:10:52 +02:00

112 lines
3.5 KiB
Python
Executable File

"""RehearsalHub FastAPI application entry point."""
from contextlib import asynccontextmanager
import os
from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from rehearsalhub.config import get_settings
from rehearsalhub.routers import (
annotations_router,
auth_router,
bands_router,
invites_router,
internal_router,
members_router,
sessions_router,
songs_router,
versions_router,
ws_router,
)
# Module-level rate limiter; each request is keyed by get_remote_address.
# create_app() attaches it to app.state and registers its exceeded handler.
limiter = Limiter(
    key_func=get_remote_address,
)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook passed to ``FastAPI(lifespan=...)``.

    Nothing is done at startup. On shutdown, the engine returned by
    ``rehearsalhub.db.engine.get_engine()`` is disposed to clean up DB
    connections.

    Args:
        app: The application instance (not used here).
    """
    try:
        yield
    finally:
        # Use ``finally`` so the cleanup also runs when an exception is thrown
        # into this context manager. Code after a bare ``yield`` would be
        # skipped in that case.
        # Imported lazily -- presumably to avoid an import-time dependency on
        # the DB layer; TODO confirm.
        from rehearsalhub.db.engine import get_engine

        engine = get_engine()
        await engine.dispose()
def _build_allowed_origins(domain: str, cors_origins: "str | None") -> "list[str]":
    """Return the CORS allow-list for *domain* plus any extra origins.

    The list always contains ``https://<domain>`` and ``http://localhost:3000``
    (presumably the local dev frontend). It also contains ``http://<domain>``
    when *domain* is not ``"localhost"``. Each non-blank, whitespace-stripped
    entry of the comma-separated *cors_origins* string is appended after
    those. Duplicates are removed, keeping first-seen order.
    """
    origins = [f"https://{domain}", "http://localhost:3000"]
    if domain != "localhost":
        origins.append(f"http://{domain}")
    if cors_origins:
        extra = (item.strip() for item in cors_origins.split(","))
        # Skip blanks left by trailing/double commas or whitespace-only items.
        origins.extend(item for item in extra if item)
    # dict.fromkeys drops duplicates while preserving insertion order.
    return list(dict.fromkeys(origins))


def create_app() -> FastAPI:
    """Build and configure the RehearsalHub FastAPI application.

    Sets up, in order:
    - the slowapi rate limiter and its ``RateLimitExceeded`` handler;
    - CORS;
    - baseline security response headers;
    - the ``/api/v1`` routers and the unprefixed WebSocket router;
    - a health-check endpoint;
    - the static mount serving uploaded avatars.

    Side effect: creates the ``uploads/avatars`` directory if it is missing.

    Returns:
        The configured ``FastAPI`` instance.
    """
    settings = get_settings()

    app = FastAPI(
        title="RehearsalHub API",
        version="0.1.0",
        docs_url="/api/docs",
        redoc_url="/api/redoc",
        openapi_url="/api/openapi.json",
        lifespan=lifespan,
    )

    # Rate limiting: expose the shared limiter on app.state and map
    # RateLimitExceeded to slowapi's default handler.
    app.state.limiter = limiter
    app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

    # CORS: built-in defaults, the plain-http domain outside localhost, and
    # any comma-separated extras from settings.cors_origins.
    allowed_origins = _build_allowed_origins(settings.domain, settings.cors_origins)
    app.add_middleware(
        CORSMiddleware,
        allow_origins=allowed_origins,
        allow_credentials=True,
        allow_methods=["GET", "POST", "PUT", "PATCH", "DELETE"],
        allow_headers=["Authorization", "Content-Type", "Accept"],
    )

    @app.middleware("http")
    async def security_headers(request: Request, call_next) -> Response:
        """Attach baseline security headers to every HTTP response."""
        response = await call_next(request)
        response.headers["X-Frame-Options"] = "DENY"
        response.headers["X-Content-Type-Options"] = "nosniff"
        response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
        # Disable the legacy XSS auditor.
        # NOTE(review): no Content-Security-Policy header is set in this
        # function; confirm CSP is applied elsewhere (e.g. a reverse proxy).
        response.headers["X-XSS-Protection"] = "0"
        return response

    prefix = "/api/v1"
    app.include_router(auth_router, prefix=prefix)
    app.include_router(bands_router, prefix=prefix)
    app.include_router(invites_router, prefix=prefix)
    app.include_router(sessions_router, prefix=prefix)
    app.include_router(songs_router, prefix=prefix)
    app.include_router(versions_router, prefix=prefix)
    app.include_router(annotations_router, prefix=prefix)
    app.include_router(members_router, prefix=prefix)
    app.include_router(internal_router, prefix=prefix)
    app.include_router(ws_router)  # WebSocket routes don't use /api/v1 prefix

    @app.get("/api/health")
    async def health():
        """Liveness probe; always reports ok."""
        return {"status": "ok"}

    # Serve avatar uploads as static files. The path is relative to the
    # process's working directory, so it is created up front to keep
    # StaticFiles from failing on a missing directory.
    upload_dir = "uploads/avatars"
    os.makedirs(upload_dir, exist_ok=True)
    app.mount("/api/static/avatars", StaticFiles(directory=upload_dir), name="avatars")

    return app
# ASGI application instance, built once when this module is imported.
app: FastAPI = create_app()