diff --git a/api/src/rehearsalhub/repositories/song.py b/api/src/rehearsalhub/repositories/song.py index 0524b21..1f79693 100644 --- a/api/src/rehearsalhub/repositories/song.py +++ b/api/src/rehearsalhub/repositories/song.py @@ -1,6 +1,7 @@ from __future__ import annotations import uuid +from typing import Any from sqlalchemy import select from sqlalchemy.orm import selectinload @@ -41,6 +42,43 @@ class SongRepository(BaseRepository[Song]): result = await self.session.execute(stmt) return result.scalar_one_or_none() + async def search( + self, + band_id: uuid.UUID, + q: str | None = None, + tags: list[str] | None = None, + key: str | None = None, + bpm_min: float | None = None, + bpm_max: float | None = None, + session_id: uuid.UUID | None = None, + ) -> list[Song]: + from sqlalchemy import cast, func + from sqlalchemy.dialects.postgresql import ARRAY + from sqlalchemy import Text + + stmt = ( + select(Song) + .where(Song.band_id == band_id) + .options(selectinload(Song.versions)) + .order_by(Song.updated_at.desc()) + ) + if q: + stmt = stmt.where(Song.title.ilike(f"%{q}%")) + if tags: + # songs.tags must contain ALL requested tags + stmt = stmt.where(Song.tags.contains(cast(tags, ARRAY(Text)))) + if key: + stmt = stmt.where(func.lower(Song.global_key) == key.lower()) + if bpm_min is not None: + stmt = stmt.where(Song.global_bpm >= bpm_min) + if bpm_max is not None: + stmt = stmt.where(Song.global_bpm <= bpm_max) + if session_id is not None: + stmt = stmt.where(Song.session_id == session_id) + + result = await self.session.execute(stmt) + return list(result.scalars().all()) + async def next_version_number(self, song_id: uuid.UUID) -> int: from sqlalchemy import func diff --git a/api/src/rehearsalhub/routers/songs.py b/api/src/rehearsalhub/routers/songs.py index 1363f70..68d997c 100644 --- a/api/src/rehearsalhub/routers/songs.py +++ b/api/src/rehearsalhub/routers/songs.py @@ -2,7 +2,7 @@ import logging import uuid from pathlib import Path -from fastapi import APIRouter, Depends, HTTPException, status +from fastapi import APIRouter, Depends, HTTPException, Query, status from pydantic import BaseModel from sqlalchemy.ext.asyncio import AsyncSession @@ -14,7 +14,7 @@ from rehearsalhub.repositories.band import BandRepository from rehearsalhub.repositories.comment import CommentRepository from rehearsalhub.repositories.song import SongRepository from rehearsalhub.schemas.comment import SongCommentCreate, SongCommentRead -from rehearsalhub.schemas.song import SongCreate, SongRead +from rehearsalhub.schemas.song import SongCreate, SongRead, SongUpdate from rehearsalhub.services.band import BandService from rehearsalhub.services.song import SongService from rehearsalhub.storage.nextcloud import NextcloudClient @@ -49,6 +49,65 @@ async def list_songs( return await song_svc.list_songs(band_id) +@router.get("/bands/{band_id}/songs/search", response_model=list[SongRead]) +async def search_songs( + band_id: uuid.UUID, + q: str | None = Query(None, description="Title substring search"), + tags: list[str] = Query(default=[], description="Songs must have ALL these tags"), + key: str | None = Query(None, description="Musical key, e.g. 'Am' or 'C'"), + bpm_min: float | None = Query(None, ge=0), + bpm_max: float | None = Query(None, ge=0), + session_id: uuid.UUID | None = Query(None), + session: AsyncSession = Depends(get_session), + current_member: Member = Depends(get_current_member), +): + band_svc = BandService(session) + try: + await band_svc.assert_membership(band_id, current_member.id) + except PermissionError: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not a member") + + song_repo = SongRepository(session) + songs = await song_repo.search( + band_id, + q=q, + tags=tags or None, + key=key, + bpm_min=bpm_min, + bpm_max=bpm_max, + session_id=session_id, + ) + return [ + SongRead.model_validate(s, update={"version_count": len(s.versions)}) + for s in songs + ] + + +@router.patch("/songs/{song_id}", response_model=SongRead) +async def update_song( + song_id: uuid.UUID, + data: SongUpdate, + session: AsyncSession = Depends(get_session), + current_member: Member = Depends(get_current_member), +): + song_repo = SongRepository(session) + song = await song_repo.get_with_versions(song_id) + if song is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Song not found") + + band_svc = BandService(session) + try: + await band_svc.assert_membership(song.band_id, current_member.id) + except PermissionError: + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not a member") + + updates = {k: v for k, v in data.model_dump().items() if v is not None} + if updates: + song = await song_repo.update(song, **updates) + + return SongRead.model_validate(song, update={"version_count": len(song.versions)}) + + @router.post("/bands/{band_id}/songs", response_model=SongRead, status_code=status.HTTP_201_CREATED) async def create_song( band_id: uuid.UUID,