# Repository viewer metadata: 93 lines, 3.2 KiB, Python, executable file.
from __future__ import annotations
|
|
|
|
import uuid
|
|
from typing import Any
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.orm import selectinload
|
|
|
|
from rehearsalhub.db.models import AudioVersion, Song
|
|
from rehearsalhub.repositories.base import BaseRepository
|
|
|
|
|
|
class SongRepository(BaseRepository[Song]):
    """Query helpers for ``Song`` rows and their ``AudioVersion`` children.

    Every method builds a ``select`` statement and awaits
    ``self.session.execute``; the session itself is presumably provided by
    ``BaseRepository`` (not visible here -- confirm against that class).
    """

    model = Song

    # Escape character for LIKE patterns built from caller-supplied text.
    # "/" is also what SQLAlchemy's own ``autoescape`` option uses.
    _LIKE_ESCAPE = "/"

    @classmethod
    def _escape_like(cls, term: str) -> str:
        """Return *term* with LIKE wildcards neutralised.

        The escape character is doubled first so that the escapes added for
        ``%`` and ``_`` afterwards are not themselves re-escaped.
        """
        esc = cls._LIKE_ESCAPE
        return (
            term.replace(esc, esc + esc)
            .replace("%", esc + "%")
            .replace("_", esc + "_")
        )

    async def list_for_band(self, band_id: uuid.UUID) -> list[Song]:
        """Return all songs of a band, most recently updated first.

        ``Song.versions`` is eagerly loaded with ``selectinload`` so callers
        can read it without triggering further lazy loads.
        """
        stmt = (
            select(Song)
            .where(Song.band_id == band_id)
            .options(selectinload(Song.versions))
            .order_by(Song.updated_at.desc())
        )
        result = await self.session.execute(stmt)
        return list(result.scalars().all())

    async def get_with_versions(self, song_id: uuid.UUID) -> Song | None:
        """Return the song with the given id, with ``versions`` eagerly loaded.

        Returns ``None`` when no song has that id.
        """
        stmt = (
            select(Song)
            .options(selectinload(Song.versions))
            .where(Song.id == song_id)
        )
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_by_nc_folder_path(self, nc_folder_path: str) -> Song | None:
        """Return the song whose ``nc_folder_path`` equals the given path.

        Returns ``None`` when nothing matches. ``scalar_one_or_none`` raises
        if more than one row matches, so this relies on the path being
        unique -- NOTE(review): confirm a uniqueness constraint exists.
        """
        stmt = select(Song).where(Song.nc_folder_path == nc_folder_path)
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_by_title_and_band(self, band_id: uuid.UUID, title: str) -> Song | None:
        """Return the band's song with exactly this title, or ``None``.

        The title comparison is a plain equality test. ``scalar_one_or_none``
        raises if several songs of the band share the title --
        NOTE(review): confirm titles are unique per band.
        """
        stmt = select(Song).where(Song.band_id == band_id, Song.title == title)
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def search(
        self,
        band_id: uuid.UUID,
        q: str | None = None,
        tags: list[str] | None = None,
        key: str | None = None,
        bpm_min: float | None = None,
        bpm_max: float | None = None,
        session_id: uuid.UUID | None = None,
        unattributed: bool = False,
    ) -> list[Song]:
        """Return a band's songs matching every supplied filter.

        Filters are combined with AND; results are ordered by ``updated_at``
        descending and have ``versions`` eagerly loaded.

        Args:
            band_id: Band whose songs are searched (always applied).
            q: Case-insensitive substring to look for in the title. ``%``
                and ``_`` are matched literally, not as wildcards. Ignored
                when empty or ``None``.
            tags: The song's tags must contain all of these. Ignored when
                empty or ``None``.
            key: Compared case-insensitively with ``Song.global_key``.
                Ignored when empty or ``None``.
            bpm_min: Inclusive lower bound on ``Song.global_bpm``.
            bpm_max: Inclusive upper bound on ``Song.global_bpm``.
            session_id: Only songs attached to this session.
            unattributed: When true, only songs with no session. Combining
                this with ``session_id`` yields no rows, because both
                conditions are applied.
        """
        # Kept function-local, as in the original, so the PostgreSQL dialect
        # is only imported when a search actually runs.
        from sqlalchemy import Text, cast, func
        from sqlalchemy.dialects.postgresql import ARRAY

        stmt = (
            select(Song)
            .where(Song.band_id == band_id)
            .options(selectinload(Song.versions))
            .order_by(Song.updated_at.desc())
        )
        if q:
            # Escape user text so "%" / "_" in the query are literal
            # characters instead of LIKE wildcards.
            pattern = f"%{self._escape_like(q)}%"
            stmt = stmt.where(Song.title.ilike(pattern, escape=self._LIKE_ESCAPE))
        if tags:
            # songs.tags must contain ALL requested tags
            stmt = stmt.where(Song.tags.contains(cast(tags, ARRAY(Text))))
        if key:
            stmt = stmt.where(func.lower(Song.global_key) == key.lower())
        if bpm_min is not None:
            stmt = stmt.where(Song.global_bpm >= bpm_min)
        if bpm_max is not None:
            stmt = stmt.where(Song.global_bpm <= bpm_max)
        if session_id is not None:
            stmt = stmt.where(Song.session_id == session_id)
        if unattributed:
            stmt = stmt.where(Song.session_id.is_(None))

        result = await self.session.execute(stmt)
        return list(result.scalars().all())

    async def next_version_number(self, song_id: uuid.UUID) -> int:
        """Return the next free version number for a song.

        This is the highest existing ``AudioVersion.version_number`` for the
        song plus one, or ``1`` when the song has no versions yet.

        NOTE(review): the value is computed with a plain SELECT, so two
        concurrent callers can receive the same number unless the caller
        serialises writes or a unique constraint rejects the duplicate.
        """
        from sqlalchemy import func

        # COALESCE turns the NULL produced by MAX() over zero rows into 0.
        highest = func.coalesce(func.max(AudioVersion.version_number), 0)
        stmt = select(highest + 1).where(AudioVersion.song_id == song_id)
        result = await self.session.execute(stmt)
        return result.scalar_one()
|