Files
rehearshalhub/api/src/rehearsalhub/repositories/annotation.py
2026-04-08 15:10:52 +02:00

114 lines
4.0 KiB
Python
Executable File

from __future__ import annotations

import uuid
from datetime import datetime, timezone
from typing import Any

from sqlalchemy import and_, select
from sqlalchemy.orm import selectinload

from rehearsalhub.db.models import Annotation, AudioVersion, RangeAnalysis, Song
from rehearsalhub.repositories.base import BaseRepository
class AnnotationRepository(BaseRepository[Annotation]):
    """Data-access helpers for ``Annotation`` rows and their range analyses."""

    model = Annotation

    # Escape character used when embedding caller input in LIKE/ILIKE patterns.
    # "/" mirrors SQLAlchemy's own ``autoescape`` choice and sidesteps backslash
    # quoting differences between database backends.
    _LIKE_ESCAPE = "/"

    @staticmethod
    def _escape_like(value: str, escape: str) -> str:
        """Return *value* with ``escape``, ``%`` and ``_`` escaped for a LIKE pattern."""
        # The escape character itself must be doubled first, otherwise the
        # escapes inserted for "%" and "_" would be escaped a second time.
        return (
            value.replace(escape, escape + escape)
            .replace("%", escape + "%")
            .replace("_", escape + "_")
        )

    async def list_for_version(self, version_id: uuid.UUID) -> list[Annotation]:
        """Return the non-deleted annotations of one audio version.

        Results are ordered by ``timestamp_ms``. The ``range_analysis``,
        ``reactions`` and ``author`` relationships are eager-loaded with
        ``selectinload``.
        """
        stmt = (
            select(Annotation)
            .where(
                Annotation.version_id == version_id,
                Annotation.deleted_at.is_(None),
            )
            .options(
                selectinload(Annotation.range_analysis),
                selectinload(Annotation.reactions),
                selectinload(Annotation.author),
            )
            .order_by(Annotation.timestamp_ms)
        )
        result = await self.session.execute(stmt)
        return list(result.scalars().all())

    async def soft_delete(self, annotation: Annotation) -> None:
        """Mark *annotation* as deleted by setting ``deleted_at`` to the current UTC time.

        The row itself is kept; the queries in this repository exclude rows
        whose ``deleted_at`` is set. This method flushes the session but does
        not commit.
        """
        annotation.deleted_at = datetime.now(timezone.utc)
        await self.session.flush()

    async def search_ranges(
        self,
        band_id: uuid.UUID,
        bpm_min: float | None = None,
        bpm_max: float | None = None,
        key: str | None = None,
        tag: str | None = None,
        min_duration_ms: int | None = None,
    ) -> list[dict[str, Any]]:
        """Search a band's analysed range annotations with optional filters.

        Only non-deleted annotations of type ``"range"`` that have a
        ``RangeAnalysis`` row are considered, because the analysis is
        inner-joined. Each filter is applied only when its argument is not
        ``None``.

        Args:
            band_id: Band whose songs are searched.
            bpm_min: Inclusive lower bound on ``RangeAnalysis.bpm``.
            bpm_max: Inclusive upper bound on ``RangeAnalysis.bpm``.
            key: Case-insensitive substring match on ``RangeAnalysis.key``.
                The characters ``%`` and ``_`` in the input match literally.
            tag: Keep only annotations whose ``tags`` contain this value.
            min_duration_ms: Minimum value of ``range_end_ms - timestamp_ms``.
                Rows whose ``range_end_ms`` is NULL never satisfy this filter.

        Returns:
            One dict per matching annotation, ordered by ``start_ms``. Each dict
            has the keys ``annotation_id``, ``song_title``, ``song_id``,
            ``version_id``, ``version_label``, ``start_ms``, ``end_ms``,
            ``label``, ``tags``, ``bpm``, ``key``, ``scale``,
            ``avg_loudness_lufs`` and ``energy``.
        """
        conditions = [
            Song.band_id == band_id,
            Annotation.type == "range",
            Annotation.deleted_at.is_(None),
        ]
        if bpm_min is not None:
            conditions.append(RangeAnalysis.bpm >= bpm_min)
        if bpm_max is not None:
            conditions.append(RangeAnalysis.bpm <= bpm_max)
        if key is not None:
            # Escape LIKE wildcards so input such as "C_" or "50%" is matched
            # literally instead of being interpreted as a pattern.
            pattern = f"%{self._escape_like(key, self._LIKE_ESCAPE)}%"
            conditions.append(
                RangeAnalysis.key.ilike(pattern, escape=self._LIKE_ESCAPE)
            )
        if tag is not None:
            conditions.append(Annotation.tags.any(tag))
        if min_duration_ms is not None:
            conditions.append(
                (Annotation.range_end_ms - Annotation.timestamp_ms) >= min_duration_ms
            )

        stmt = (
            select(
                Annotation.id.label("annotation_id"),
                Song.title.label("song_title"),
                Song.id.label("song_id"),
                AudioVersion.id.label("version_id"),
                AudioVersion.label.label("version_label"),
                Annotation.timestamp_ms.label("start_ms"),
                Annotation.range_end_ms.label("end_ms"),
                Annotation.label.label("label"),
                Annotation.tags.label("tags"),
                RangeAnalysis.bpm,
                RangeAnalysis.key,
                RangeAnalysis.scale,
                RangeAnalysis.avg_loudness_lufs,
                RangeAnalysis.energy,
            )
            .join(AudioVersion, Annotation.version_id == AudioVersion.id)
            .join(Song, AudioVersion.song_id == Song.id)
            .join(RangeAnalysis, RangeAnalysis.annotation_id == Annotation.id)
            .where(and_(*conditions))
            .order_by(Annotation.timestamp_ms)
        )
        result = await self.session.execute(stmt)
        return [row._asdict() for row in result]

    async def list_all_ranges_for_band(self, band_id: uuid.UUID) -> list[Annotation]:
        """Return every non-deleted range annotation of a band, newest first.

        Unlike ``search_ranges``, annotations without a ``RangeAnalysis`` row
        are included, because the analysis is eager-loaded rather than joined.
        The ``range_analysis`` and ``author`` relationships are loaded with
        ``selectinload``; ``reactions`` is not eager-loaded here.
        """
        stmt = (
            select(Annotation)
            .join(AudioVersion, Annotation.version_id == AudioVersion.id)
            .join(Song, AudioVersion.song_id == Song.id)
            .where(
                Song.band_id == band_id,
                Annotation.type == "range",
                Annotation.deleted_at.is_(None),
            )
            .options(
                selectinload(Annotation.range_analysis),
                selectinload(Annotation.author),
            )
            .order_by(Annotation.created_at.desc())
        )
        result = await self.session.execute(stmt)
        return list(result.scalars().all())