from __future__ import annotations

import uuid
from typing import Any

from sqlalchemy import Text, cast, func, select
from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy.orm import selectinload

from rehearsalhub.db.models import AudioVersion, Song
from rehearsalhub.repositories.base import BaseRepository

# Escape character used when matching user-supplied text with LIKE/ILIKE.
_LIKE_ESCAPE = "/"


def _escape_like(value: str, escape: str = _LIKE_ESCAPE) -> str:
    """Return *value* with LIKE wildcards escaped so they match literally.

    The escape character itself is doubled first so that an escape character
    already present in *value* cannot pair up with a following ``%`` or ``_``.
    """
    return (
        value.replace(escape, escape + escape)
        .replace("%", escape + "%")
        .replace("_", escape + "_")
    )


class SongRepository(BaseRepository[Song]):
    """Query helpers for :class:`Song` rows and their audio versions."""

    model = Song

    async def list_for_band(self, band_id: uuid.UUID) -> list[Song]:
        """Return all songs belonging to *band_id*, most recently updated first.

        ``Song.versions`` is eagerly loaded. This is exactly :meth:`search`
        called with no filters.
        """
        return await self.search(band_id)

    async def get_with_versions(self, song_id: uuid.UUID) -> Song | None:
        """Return the song with id *song_id* with its versions loaded, or ``None``."""
        stmt = (
            select(Song)
            .options(selectinload(Song.versions))
            .where(Song.id == song_id)
        )
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_by_nc_folder_path(self, nc_folder_path: str) -> Song | None:
        """Return the song whose ``nc_folder_path`` equals *nc_folder_path*, or ``None``.

        ``scalar_one_or_none`` raises ``MultipleResultsFound`` if more than one
        row matches. Presumably the column is unique; that is not visible here.
        """
        stmt = select(Song).where(Song.nc_folder_path == nc_folder_path)
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_by_title_and_band(self, band_id: uuid.UUID, title: str) -> Song | None:
        """Return the song of *band_id* whose title exactly equals *title*, or ``None``.

        ``scalar_one_or_none`` raises ``MultipleResultsFound`` if the band has
        more than one song with this title.
        """
        stmt = select(Song).where(Song.band_id == band_id, Song.title == title)
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def search(
        self,
        band_id: uuid.UUID,
        q: str | None = None,
        tags: list[str] | None = None,
        key: str | None = None,
        bpm_min: float | None = None,
        bpm_max: float | None = None,
        session_id: uuid.UUID | None = None,
        unattributed: bool = False,
    ) -> list[Song]:
        """Return the band's songs matching every supplied filter.

        Results are ordered by ``updated_at`` descending, with ``Song.versions``
        eagerly loaded. All filters are combined with AND.

        Args:
            band_id: Only songs of this band are returned.
            q: Case-insensitive substring of the title. ``%`` and ``_`` are
                matched literally. Empty or ``None`` disables the filter.
            tags: The song's tags must contain *all* of these. Empty or
                ``None`` disables the filter.
            key: Case-insensitive exact match on ``global_key``. Empty or
                ``None`` disables the filter.
            bpm_min: Inclusive lower bound on ``global_bpm`` (``0`` is applied).
            bpm_max: Inclusive upper bound on ``global_bpm`` (``0`` is applied).
            session_id: Only songs attached to this session.
            unattributed: Only songs with no session. Combining this with
                *session_id* matches nothing, since ``session_id`` cannot
                equal a value and be NULL at the same time.
        """
        stmt = (
            select(Song)
            .where(Song.band_id == band_id)
            .options(selectinload(Song.versions))
            .order_by(Song.updated_at.desc())
        )
        if q:
            # Escape LIKE metacharacters so user text is a literal substring.
            pattern = f"%{_escape_like(q)}%"
            stmt = stmt.where(Song.title.ilike(pattern, escape=_LIKE_ESCAPE))
        if tags:
            # songs.tags must contain ALL requested tags (array containment).
            stmt = stmt.where(Song.tags.contains(cast(tags, ARRAY(Text))))
        if key:
            stmt = stmt.where(func.lower(Song.global_key) == key.lower())
        if bpm_min is not None:
            stmt = stmt.where(Song.global_bpm >= bpm_min)
        if bpm_max is not None:
            stmt = stmt.where(Song.global_bpm <= bpm_max)
        if session_id is not None:
            stmt = stmt.where(Song.session_id == session_id)
        if unattributed:
            stmt = stmt.where(Song.session_id.is_(None))
        result = await self.session.execute(stmt)
        return list(result.scalars().all())

    async def next_version_number(self, song_id: uuid.UUID) -> int:
        """Return the next ``version_number`` to use for *song_id*.

        This is ``MAX(version_number) + 1`` over the song's audio versions, or
        ``1`` when the song has none (``COALESCE`` turns the NULL max into 0).

        NOTE(review): this reads before the caller inserts. Two concurrent
        callers can receive the same number unless inserts are serialized or
        a unique constraint rejects the duplicate. Neither is visible here.
        """
        stmt = select(func.coalesce(func.max(AudioVersion.version_number), 0) + 1).where(
            AudioVersion.song_id == song_id
        )
        result = await self.session.execute(stmt)
        return result.scalar_one()