114 lines
4.0 KiB
Python
Executable File
114 lines
4.0 KiB
Python
Executable File
from __future__ import annotations

import uuid
from datetime import datetime, timezone
from typing import Any

from sqlalchemy import and_, select
from sqlalchemy.orm import selectinload

from rehearsalhub.db.models import Annotation, AudioVersion, RangeAnalysis, Song
from rehearsalhub.repositories.base import BaseRepository
|
|
|
|
|
|
class AnnotationRepository(BaseRepository[Annotation]):
    """Data-access helpers for ``Annotation`` rows.

    Every read query in this repository excludes soft-deleted annotations
    (rows whose ``deleted_at`` is set).
    """

    model = Annotation

    # Escape character used for LIKE/ILIKE patterns built from user input.
    # "/" (rather than "\\") avoids dialect-specific backslash handling and
    # matches the scheme SQLAlchemy itself uses for ``autoescape=True``.
    _LIKE_ESCAPE = "/"

    @staticmethod
    def _escape_like(value: str) -> str:
        """Return *value* with LIKE wildcards escaped using ``/``.

        The escape character itself is escaped first so that the escapes
        added for ``%`` and ``_`` are not doubled afterwards.
        """
        return (
            value.replace("/", "//")
            .replace("%", "/%")
            .replace("_", "/_")
        )

    async def list_for_version(self, version_id: uuid.UUID) -> list[Annotation]:
        """Return the non-deleted annotations of one audio version.

        The ``range_analysis``, ``reactions`` and ``author`` relationships
        are eager-loaded with ``selectinload``.

        Args:
            version_id: Id of the audio version whose annotations are listed.

        Returns:
            Annotations ordered by ``timestamp_ms`` ascending.
        """
        stmt = (
            select(Annotation)
            .where(
                Annotation.version_id == version_id,
                Annotation.deleted_at.is_(None),
            )
            .options(
                selectinload(Annotation.range_analysis),
                selectinload(Annotation.reactions),
                selectinload(Annotation.author),
            )
            .order_by(Annotation.timestamp_ms)
        )
        result = await self.session.execute(stmt)
        return list(result.scalars().all())

    async def soft_delete(self, annotation: Annotation) -> None:
        """Mark *annotation* as deleted without removing the row.

        Sets ``deleted_at`` to the current time as a timezone-aware UTC
        datetime, then flushes the session. No commit is issued here.
        """
        annotation.deleted_at = datetime.now(timezone.utc)
        await self.session.flush()

    async def search_ranges(
        self,
        band_id: uuid.UUID,
        bpm_min: float | None = None,
        bpm_max: float | None = None,
        key: str | None = None,
        tag: str | None = None,
        min_duration_ms: int | None = None,
    ) -> list[dict[str, Any]]:
        """Search a band's range annotations by their analysis attributes.

        Only annotations that meet all of the following are considered:

        * type ``"range"``;
        * not soft-deleted;
        * belong to a song of *band_id*;
        * have a ``RangeAnalysis`` row (the join is an inner join).

        Each optional filter is applied only when it is not ``None``.

        Args:
            band_id: Band whose songs are searched.
            bpm_min: Inclusive lower bound on ``RangeAnalysis.bpm``.
            bpm_max: Inclusive upper bound on ``RangeAnalysis.bpm``.
            key: Case-insensitive substring matched against
                ``RangeAnalysis.key``. ``%``, ``_`` and ``/`` in the input
                are matched literally, not as wildcards.
            tag: Value that must appear in ``Annotation.tags`` (presumably a
                PostgreSQL ARRAY column — ``.any()`` is used).
            min_duration_ms: Minimum ``range_end_ms - timestamp_ms``. Ranges
                with a NULL ``range_end_ms`` never satisfy this (NULL
                arithmetic).

        Returns:
            One dict per match, ordered by ``timestamp_ms`` ascending, with
            these keys:

            * ``annotation_id``, ``song_title``, ``song_id``;
            * ``version_id``, ``version_label``;
            * ``start_ms``, ``end_ms``, ``label``, ``tags``;
            * ``bpm``, ``key``, ``scale``, ``avg_loudness_lufs``, ``energy``.
        """
        conditions = [
            Song.band_id == band_id,
            Annotation.type == "range",
            Annotation.deleted_at.is_(None),
        ]
        if bpm_min is not None:
            conditions.append(RangeAnalysis.bpm >= bpm_min)
        if bpm_max is not None:
            conditions.append(RangeAnalysis.bpm <= bpm_max)
        if key is not None:
            # Escape user input so "%" / "_" are not interpreted as wildcards.
            pattern = f"%{self._escape_like(key)}%"
            conditions.append(
                RangeAnalysis.key.ilike(pattern, escape=self._LIKE_ESCAPE)
            )
        if tag is not None:
            conditions.append(Annotation.tags.any(tag))
        if min_duration_ms is not None:
            conditions.append(
                (Annotation.range_end_ms - Annotation.timestamp_ms) >= min_duration_ms
            )

        stmt = (
            select(
                Annotation.id.label("annotation_id"),
                Song.title.label("song_title"),
                Song.id.label("song_id"),
                AudioVersion.id.label("version_id"),
                AudioVersion.label.label("version_label"),
                Annotation.timestamp_ms.label("start_ms"),
                Annotation.range_end_ms.label("end_ms"),
                Annotation.label.label("label"),
                Annotation.tags.label("tags"),
                RangeAnalysis.bpm,
                RangeAnalysis.key,
                RangeAnalysis.scale,
                RangeAnalysis.avg_loudness_lufs,
                RangeAnalysis.energy,
            )
            .join(AudioVersion, Annotation.version_id == AudioVersion.id)
            .join(Song, AudioVersion.song_id == Song.id)
            .join(RangeAnalysis, RangeAnalysis.annotation_id == Annotation.id)
            .where(and_(*conditions))
            .order_by(Annotation.timestamp_ms)
        )
        result = await self.session.execute(stmt)
        return [row._asdict() for row in result]

    async def list_all_ranges_for_band(self, band_id: uuid.UUID) -> list[Annotation]:
        """Return every non-deleted range annotation across a band's songs.

        Unlike ``search_ranges``, annotations without a ``RangeAnalysis``
        row are included (there is no join on it). ``range_analysis`` and
        ``author`` are eager-loaded with ``selectinload``.

        Args:
            band_id: Band whose songs' range annotations are listed.

        Returns:
            ORM ``Annotation`` objects, newest first by ``created_at``.
        """
        stmt = (
            select(Annotation)
            .join(AudioVersion, Annotation.version_id == AudioVersion.id)
            .join(Song, AudioVersion.song_id == Song.id)
            .where(
                Song.band_id == band_id,
                Annotation.type == "range",
                Annotation.deleted_at.is_(None),
            )
            .options(
                selectinload(Annotation.range_analysis),
                selectinload(Annotation.author),
            )
            .order_by(Annotation.created_at.desc())
        )
        result = await self.session.execute(stmt)
        return list(result.scalars().all())
|