"""Repository for looking up and creating rehearsal sessions."""

from __future__ import annotations

import uuid
from datetime import date, datetime

from sqlalchemy import func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import selectinload

from rehearsalhub.db.models import RehearsalSession, Song
from rehearsalhub.repositories.base import BaseRepository


def _day_start(session_date: date) -> datetime:
    """Return the naive midnight ``datetime`` for *session_date*.

    Lookups and inserts both go through this helper so that the value written
    by ``get_or_create`` is exactly the value ``get_by_band_and_date`` compares
    against.
    """
    return datetime(session_date.year, session_date.month, session_date.day)


class RehearsalSessionRepository(BaseRepository[RehearsalSession]):
    """Data-access helpers for ``RehearsalSession`` rows."""

    model = RehearsalSession

    async def get_by_nc_folder(
        self, band_id: uuid.UUID, nc_folder_path: str
    ) -> RehearsalSession | None:
        """Return the band's session stored under *nc_folder_path*, if any.

        Returns ``None`` when no row matches. ``scalar_one_or_none`` raises
        ``MultipleResultsFound`` if more than one row matches.
        """
        stmt = select(RehearsalSession).where(
            RehearsalSession.band_id == band_id,
            RehearsalSession.nc_folder_path == nc_folder_path,
        )
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_by_band_and_date(
        self, band_id: uuid.UUID, session_date: date
    ) -> RehearsalSession | None:
        """Return the band's session dated *session_date*, if any.

        Returns ``None`` when no row matches. ``scalar_one_or_none`` raises
        ``MultipleResultsFound`` if more than one row matches.
        """
        # The column (stored as DateTime(timezone=False)) is compared for
        # equality with midnight of the requested day. get_or_create writes
        # rows with that exact value; a row stored with a time-of-day
        # component would NOT be matched by this query.
        stmt = select(RehearsalSession).where(
            RehearsalSession.band_id == band_id,
            RehearsalSession.date == _day_start(session_date),
        )
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_or_create(
        self,
        band_id: uuid.UUID,
        session_date: date,
        nc_folder_path: str,
    ) -> RehearsalSession:
        """Return the band's session for *session_date*, creating it if absent.

        *nc_folder_path* is only used when a new row is inserted; an existing
        session for the same band and date is returned unchanged.

        Raises:
            IntegrityError: the insert violated a constraint and no session
                for this band and date exists afterwards.
        """
        existing = await self.get_by_band_and_date(band_id, session_date)
        if existing is not None:
            return existing

        try:
            # Run the insert inside a SAVEPOINT so that a failed insert rolls
            # back only the nested transaction, leaving the enclosing one
            # usable for the re-read below.
            async with self.session.begin_nested():
                return await self.create(
                    band_id=band_id,
                    date=_day_start(session_date),
                    nc_folder_path=nc_folder_path,
                )
        except IntegrityError:
            # Another request raced us — fetch the row that now exists.
            # NOTE(review): this presumes a uniqueness constraint covering
            # (band_id, date); confirm against the model definition.
            existing = await self.get_by_band_and_date(band_id, session_date)
            if existing is not None:
                return existing
            # The violation was not the duplicate we expected; surface it.
            raise

    async def list_for_band(
        self, band_id: uuid.UUID
    ) -> list[tuple[RehearsalSession, int]]:
        """Return (session, recording_count) tuples, newest date first."""
        # COUNT(Song.id) over an outer join yields 0 for sessions that have
        # no songs, because NULL ids are not counted.
        count_col = func.count(Song.id).label("recording_count")
        stmt = (
            select(RehearsalSession, count_col)
            .outerjoin(Song, Song.session_id == RehearsalSession.id)
            .where(RehearsalSession.band_id == band_id)
            .group_by(RehearsalSession.id)
            .order_by(RehearsalSession.date.desc())
        )
        result = await self.session.execute(stmt)
        return [(session, count) for session, count in result.all()]

    async def get_with_songs(
        self, session_id: uuid.UUID
    ) -> RehearsalSession | None:
        """Return the session with id *session_id*, or ``None`` if not found.

        The session's ``songs`` and each song's ``versions`` are loaded via
        ``selectinload`` as part of this call.
        """
        stmt = (
            select(RehearsalSession)
            .options(
                selectinload(RehearsalSession.songs).selectinload(Song.versions)
            )
            .where(RehearsalSession.id == session_id)
        )
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()