Files
rehearshalhub/api/src/rehearsalhub/repositories/song.py
2026-04-08 15:10:52 +02:00

93 lines
3.2 KiB
Python
Executable File

from __future__ import annotations
import uuid
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from rehearsalhub.db.models import AudioVersion, Song
from rehearsalhub.repositories.base import BaseRepository
class SongRepository(BaseRepository[Song]):
    """Query helpers for ``Song`` rows and their related ``AudioVersion`` rows."""

    # ORM model this repository operates on.
    model = Song

    # Escape character used when embedding user text in LIKE/ILIKE patterns.
    _LIKE_ESCAPE = "\\"

    @classmethod
    def _escape_like(cls, term: str) -> str:
        """Return *term* with LIKE metacharacters escaped for literal matching."""
        esc = cls._LIKE_ESCAPE
        # The escape character itself must be doubled first, otherwise the
        # escapes added for ``%`` and ``_`` would be escaped a second time.
        return (
            term.replace(esc, esc + esc)
            .replace("%", esc + "%")
            .replace("_", esc + "_")
        )

    async def list_for_band(self, band_id: uuid.UUID) -> list[Song]:
        """Return every song of a band, most recently updated first.

        ``Song.versions`` is eager-loaded with ``selectinload`` so the
        relationship can be read without issuing further queries.
        """
        stmt = (
            select(Song)
            .where(Song.band_id == band_id)
            .options(selectinload(Song.versions))
            .order_by(Song.updated_at.desc())
        )
        result = await self.session.execute(stmt)
        return list(result.scalars().all())

    async def get_with_versions(self, song_id: uuid.UUID) -> Song | None:
        """Return the song with the given id and its versions, or ``None``."""
        stmt = (
            select(Song)
            .options(selectinload(Song.versions))
            .where(Song.id == song_id)
        )
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_by_nc_folder_path(self, nc_folder_path: str) -> Song | None:
        """Return the song whose ``nc_folder_path`` equals the argument, or ``None``.

        ``scalar_one_or_none`` raises ``MultipleResultsFound`` when more than
        one row matches.
        NOTE(review): presumably ``nc_folder_path`` is unique per song —
        confirm against the model/constraints.
        """
        stmt = select(Song).where(Song.nc_folder_path == nc_folder_path)
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_by_title_and_band(self, band_id: uuid.UUID, title: str) -> Song | None:
        """Return the band's song with exactly this title, or ``None``.

        The title comparison is an exact equality match.
        ``scalar_one_or_none`` raises ``MultipleResultsFound`` when more than
        one row matches.
        NOTE(review): presumably ``(band_id, title)`` is unique — confirm
        against the model/constraints.
        """
        stmt = select(Song).where(Song.band_id == band_id, Song.title == title)
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def search(
        self,
        band_id: uuid.UUID,
        q: str | None = None,
        tags: list[str] | None = None,
        key: str | None = None,
        bpm_min: float | None = None,
        bpm_max: float | None = None,
        session_id: uuid.UUID | None = None,
        unattributed: bool = False,
    ) -> list[Song]:
        """Return a band's songs matching all supplied filters.

        Filters are combined with AND; a filter left at its default is not
        applied. Results are ordered by ``updated_at`` descending and have
        ``Song.versions`` eager-loaded.

        Args:
            band_id: Band whose songs are searched (always applied).
            q: Case-insensitive substring to look for in the title. The text
                is matched literally: ``%``, ``_`` and ``\\`` are escaped and
                do not act as wildcards. Empty string disables the filter.
            tags: The song's tags must contain all of these. An empty list
                disables the filter.
            key: Compared case-insensitively with ``Song.global_key``.
            bpm_min: Inclusive lower bound on ``Song.global_bpm``.
            bpm_max: Inclusive upper bound on ``Song.global_bpm``.
            session_id: Only songs with this ``session_id``.
            unattributed: When true, only songs whose ``session_id`` is NULL.
                Combined with ``session_id`` the two conditions contradict
                each other and no row can match.
        """
        from sqlalchemy import Text, cast, func
        from sqlalchemy.dialects.postgresql import ARRAY

        stmt = (
            select(Song)
            .where(Song.band_id == band_id)
            .options(selectinload(Song.versions))
            .order_by(Song.updated_at.desc())
        )
        if q:
            # Escape LIKE metacharacters so user input cannot widen the match
            # (e.g. a bare "_" would otherwise match every non-empty title).
            pattern = f"%{self._escape_like(q)}%"
            stmt = stmt.where(Song.title.ilike(pattern, escape=self._LIKE_ESCAPE))
        if tags:
            # songs.tags must contain ALL requested tags
            stmt = stmt.where(Song.tags.contains(cast(tags, ARRAY(Text))))
        if key:
            stmt = stmt.where(func.lower(Song.global_key) == key.lower())
        if bpm_min is not None:
            stmt = stmt.where(Song.global_bpm >= bpm_min)
        if bpm_max is not None:
            stmt = stmt.where(Song.global_bpm <= bpm_max)
        if session_id is not None:
            stmt = stmt.where(Song.session_id == session_id)
        if unattributed:
            stmt = stmt.where(Song.session_id.is_(None))
        result = await self.session.execute(stmt)
        return list(result.scalars().all())

    async def next_version_number(self, song_id: uuid.UUID) -> int:
        """Return the highest existing version number for the song plus one.

        Returns ``1`` when the song has no ``AudioVersion`` rows yet (the
        missing maximum is coalesced to ``0``).

        NOTE(review): this is a plain read; two concurrent callers can obtain
        the same number. Confirm that callers serialize version creation or
        that a uniqueness constraint guards ``(song_id, version_number)``.
        """
        from sqlalchemy import func

        stmt = select(func.coalesce(func.max(AudioVersion.version_number), 0) + 1).where(
            AudioVersion.song_id == song_id
        )
        result = await self.session.execute(stmt)
        return result.scalar_one()