from __future__ import annotations import uuid from typing import Any from sqlalchemy import and_, select from sqlalchemy.orm import selectinload from rehearsalhub.db.models import Annotation, RangeAnalysis from rehearsalhub.repositories.base import BaseRepository class AnnotationRepository(BaseRepository[Annotation]): model = Annotation async def list_for_version(self, version_id: uuid.UUID) -> list[Annotation]: stmt = ( select(Annotation) .where( Annotation.version_id == version_id, Annotation.deleted_at.is_(None), ) .options( selectinload(Annotation.range_analysis), selectinload(Annotation.reactions), selectinload(Annotation.author), ) .order_by(Annotation.timestamp_ms) ) result = await self.session.execute(stmt) return list(result.scalars().all()) async def soft_delete(self, annotation: Annotation) -> None: from datetime import datetime, timezone annotation.deleted_at = datetime.now(timezone.utc) await self.session.flush() async def search_ranges( self, band_id: uuid.UUID, bpm_min: float | None = None, bpm_max: float | None = None, key: str | None = None, tag: str | None = None, min_duration_ms: int | None = None, ) -> list[dict[str, Any]]: from rehearsalhub.db.models import AudioVersion, RangeAnalysis, Song conditions = [ Song.band_id == band_id, Annotation.type == "range", Annotation.deleted_at.is_(None), ] if bpm_min is not None: conditions.append(RangeAnalysis.bpm >= bpm_min) if bpm_max is not None: conditions.append(RangeAnalysis.bpm <= bpm_max) if key is not None: conditions.append(RangeAnalysis.key.ilike(f"%{key}%")) if tag is not None: conditions.append(Annotation.tags.any(tag)) if min_duration_ms is not None: conditions.append( (Annotation.range_end_ms - Annotation.timestamp_ms) >= min_duration_ms ) stmt = ( select( Annotation.id.label("annotation_id"), Song.title.label("song_title"), Song.id.label("song_id"), AudioVersion.id.label("version_id"), AudioVersion.label.label("version_label"), Annotation.timestamp_ms.label("start_ms"), Annotation.range_end_ms.label("end_ms"), Annotation.label.label("label"), Annotation.tags.label("tags"), RangeAnalysis.bpm, RangeAnalysis.key, RangeAnalysis.scale, RangeAnalysis.avg_loudness_lufs, RangeAnalysis.energy, ) .join(AudioVersion, Annotation.version_id == AudioVersion.id) .join(Song, AudioVersion.song_id == Song.id) .join(RangeAnalysis, RangeAnalysis.annotation_id == Annotation.id) .where(and_(*conditions)) .order_by(Annotation.timestamp_ms) ) result = await self.session.execute(stmt) return [row._asdict() for row in result] async def list_all_ranges_for_band(self, band_id: uuid.UUID) -> list[Annotation]: from rehearsalhub.db.models import AudioVersion, Song stmt = ( select(Annotation) .join(AudioVersion, Annotation.version_id == AudioVersion.id) .join(Song, AudioVersion.song_id == Song.id) .where( Song.band_id == band_id, Annotation.type == "range", Annotation.deleted_at.is_(None), ) .options( selectinload(Annotation.range_analysis), selectinload(Annotation.author), ) .order_by(Annotation.created_at.desc()) ) result = await self.session.execute(stmt) return list(result.scalars().all())