Initial commit: RehearsalHub POC
Full-stack self-hosted band rehearsal platform: Backend (FastAPI + SQLAlchemy 2.0 async): - Auth with JWT (register, login, /me, settings) - Band management with Nextcloud folder integration - Song management with audio version tracking - Nextcloud scan to auto-import audio files - Band membership with link-based invite system - Song comments - Audio analysis worker (BPM, key, loudness, waveform) - Nextcloud activity watcher for auto-import - WebSocket support for real-time annotation updates - Alembic migrations (0001–0003) - Repository pattern, Ruff + mypy configured Frontend (React 18 + Vite + TypeScript strict): - Login/register page with post-login redirect - Home page with band list and creation form - Band page with member panel, invite link, song list, NC scan - Song page with waveform player, annotations, comment thread - Settings page for per-user Nextcloud credentials - Invite acceptance page (/invite/:token) - ESLint v9 flat config + TypeScript strict mode Infrastructure: - Docker Compose: PostgreSQL, Redis, API, worker, watcher, nginx - nginx reverse proxy for static files + /api/ proxy - make check runs all linters before docker compose build Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
30
worker/Dockerfile
Normal file
30
worker/Dockerfile
Normal file
@@ -0,0 +1,30 @@
|
||||
# Stage 1: Essentia builder
# Essentia doesn't have wheels for Python 3.12 yet; we use the official image
# and copy the bindings into our final stage via a bind mount.
FROM mtgupf/essentia:latest AS essentia-builder

FROM python:3.12-slim AS base
WORKDIR /app

# System dependencies for audio processing
RUN apt-get update && apt-get install -y --no-install-recommends \
    ffmpeg \
    libsndfile1 \
    && rm -rf /var/lib/apt/lists/*

# Copy Essentia Python bindings from builder (best-effort: no-op if the library
# wasn't built for this Python version or the path doesn't exist).
# COPY does not support shell redirections, so we use RUN --mount instead.
RUN --mount=type=bind,from=essentia-builder,source=/usr/local/lib,target=/essentia_lib \
    find /essentia_lib -maxdepth 4 -name "essentia*" \
        -exec cp -r {} /usr/local/lib/python3.12/site-packages/ \; \
        2>/dev/null || true

RUN pip install uv

FROM base AS production
# Copy the lock file too (glob keeps the build working when it is absent).
# Previously only pyproject.toml was copied, so `uv sync --frozen` could never
# succeed and every build fell through to an unpinned re-resolution.
COPY pyproject.toml uv.lock* ./
RUN uv sync --no-dev --frozen || uv sync --no-dev
COPY . .
ENV PYTHONPATH=/app/src
CMD ["uv", "run", "python", "-m", "worker.main"]
|
||||
36
worker/pyproject.toml
Normal file
36
worker/pyproject.toml
Normal file
@@ -0,0 +1,36 @@
|
||||
# Packaging and tooling configuration for the RehearsalHub audio worker.
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "rehearsalhub-worker"
version = "0.1.0"
requires-python = ">=3.12"
dependencies = [
    "sqlalchemy[asyncio]>=2.0",
    "asyncpg>=0.29",
    "pydantic-settings>=2.3",
    "redis[hiredis]>=5.0",
    "httpx>=0.27",
    "pydub>=0.25",
    "pyloudnorm>=0.1",
    "librosa>=0.10",
    "numpy>=1.26",
    "scipy>=1.13",
    # essentia installed via system in Dockerfile (not via pip)
]

[project.optional-dependencies]
dev = [
    "pytest>=8",
    "pytest-asyncio>=0.23",
    "pytest-cov>=5",
    "ruff>=0.4",
]

# src layout: the importable package lives at src/worker.
[tool.hatch.build.targets.wheel]
packages = ["src/worker"]

# asyncio_mode=auto lets async tests run without per-test markers.
[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]
||||
0
worker/src/worker/__init__.py
Normal file
0
worker/src/worker/__init__.py
Normal file
28
worker/src/worker/analyzers/__init__.py
Normal file
28
worker/src/worker/analyzers/__init__.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Analyzer package: exposes the plugin registry and every built-in analyzer."""

from worker.analyzers.base import AnalysisResult, AudioAnalyzer
from worker.analyzers.bpm import BPMAnalyzer
from worker.analyzers.chroma import ChromaAnalyzer
from worker.analyzers.key import KeyAnalyzer
from worker.analyzers.loudness import LoudnessAnalyzer
from worker.analyzers.mfcc import MFCCAnalyzer
from worker.analyzers.spectral import SpectralAnalyzer

# One stateless instance per analyzer class; this order is the order the
# pipeline runs them in.
_ANALYZER_CLASSES = (
    BPMAnalyzer,
    KeyAnalyzer,
    LoudnessAnalyzer,
    SpectralAnalyzer,
    ChromaAnalyzer,
    MFCCAnalyzer,
)
REGISTRY: list[AudioAnalyzer] = [cls() for cls in _ANALYZER_CLASSES]

__all__ = [
    "AnalysisResult",
    "AudioAnalyzer",
    "BPMAnalyzer",
    "ChromaAnalyzer",
    "KeyAnalyzer",
    "LoudnessAnalyzer",
    "MFCCAnalyzer",
    "REGISTRY",
    "SpectralAnalyzer",
]
|
||||
38
worker/src/worker/analyzers/base.py
Normal file
38
worker/src/worker/analyzers/base.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""Audio analyzer plugin interface.
|
||||
|
||||
To add a new analyzer:
|
||||
1. Subclass AudioAnalyzer
|
||||
2. Implement `name` and `analyze()`
|
||||
3. Add an instance to REGISTRY in analyse_range.py
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass
|
||||
class AnalysisResult:
|
||||
analyzer_name: str
|
||||
fields: dict[str, float | list[float] | str | None] = field(default_factory=dict)
|
||||
|
||||
|
||||
class AudioAnalyzer(ABC):
|
||||
"""Base class for all audio analyzers. Each analyzer is stateless."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Unique short name for this analyzer (e.g. 'bpm', 'key')."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def analyze(self, audio: np.ndarray, sample_rate: int) -> AnalysisResult:
|
||||
"""
|
||||
Analyze a mono float32 audio array at the given sample rate.
|
||||
Must be synchronous (called in a thread pool by the pipeline).
|
||||
"""
|
||||
...
|
||||
39
worker/src/worker/analyzers/bpm.py
Normal file
39
worker/src/worker/analyzers/bpm.py
Normal file
@@ -0,0 +1,39 @@
|
||||
"""BPM analyzer using Essentia (primary) with librosa fallback."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
from worker.analyzers.base import AnalysisResult, AudioAnalyzer
|
||||
|
||||
|
||||
class BPMAnalyzer(AudioAnalyzer):
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "bpm"
|
||||
|
||||
def analyze(self, audio: np.ndarray, sample_rate: int) -> AnalysisResult:
|
||||
try:
|
||||
return self._essentia_bpm(audio, sample_rate)
|
||||
except Exception:
|
||||
return self._librosa_bpm(audio, sample_rate)
|
||||
|
||||
def _essentia_bpm(self, audio: np.ndarray, sample_rate: int) -> AnalysisResult:
|
||||
import essentia.standard as es # type: ignore[import]
|
||||
|
||||
rhythm = es.RhythmExtractor2013(method="multifeature")
|
||||
bpm, beats, bpm_confidence, _, _ = rhythm(audio)
|
||||
return AnalysisResult(
|
||||
analyzer_name=self.name,
|
||||
fields={"bpm": float(bpm), "bpm_confidence": float(bpm_confidence)},
|
||||
)
|
||||
|
||||
def _librosa_bpm(self, audio: np.ndarray, sample_rate: int) -> AnalysisResult:
|
||||
import librosa
|
||||
|
||||
tempo, _ = librosa.beat.beat_track(y=audio, sr=sample_rate)
|
||||
bpm = float(tempo[0]) if hasattr(tempo, "__len__") else float(tempo)
|
||||
return AnalysisResult(
|
||||
analyzer_name=self.name,
|
||||
fields={"bpm": bpm, "bpm_confidence": None},
|
||||
)
|
||||
21
worker/src/worker/analyzers/chroma.py
Normal file
21
worker/src/worker/analyzers/chroma.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
from worker.analyzers.base import AnalysisResult, AudioAnalyzer
|
||||
|
||||
|
||||
class ChromaAnalyzer(AudioAnalyzer):
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "chroma"
|
||||
|
||||
def analyze(self, audio: np.ndarray, sample_rate: int) -> AnalysisResult:
|
||||
import librosa
|
||||
|
||||
chroma = librosa.feature.chroma_cqt(y=audio, sr=sample_rate, n_chroma=12)
|
||||
chroma_mean = np.mean(chroma, axis=1).tolist()
|
||||
return AnalysisResult(
|
||||
analyzer_name=self.name,
|
||||
fields={"chroma_vector": [float(v) for v in chroma_mean]},
|
||||
)
|
||||
31
worker/src/worker/analyzers/key.py
Normal file
31
worker/src/worker/analyzers/key.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
from worker.analyzers.base import AnalysisResult, AudioAnalyzer
|
||||
|
||||
|
||||
class KeyAnalyzer(AudioAnalyzer):
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "key"
|
||||
|
||||
def analyze(self, audio: np.ndarray, sample_rate: int) -> AnalysisResult:
|
||||
try:
|
||||
import essentia.standard as es # type: ignore[import]
|
||||
|
||||
key_extractor = es.KeyExtractor()
|
||||
key, scale, confidence = key_extractor(audio)
|
||||
return AnalysisResult(
|
||||
analyzer_name=self.name,
|
||||
fields={
|
||||
"key": f"{key} {scale}",
|
||||
"scale": scale,
|
||||
"key_confidence": float(confidence),
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
return AnalysisResult(
|
||||
analyzer_name=self.name,
|
||||
fields={"key": None, "scale": None, "key_confidence": None},
|
||||
)
|
||||
33
worker/src/worker/analyzers/loudness.py
Normal file
33
worker/src/worker/analyzers/loudness.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
from worker.analyzers.base import AnalysisResult, AudioAnalyzer
|
||||
|
||||
|
||||
class LoudnessAnalyzer(AudioAnalyzer):
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "loudness"
|
||||
|
||||
def analyze(self, audio: np.ndarray, sample_rate: int) -> AnalysisResult:
|
||||
import pyloudnorm as pyln # type: ignore[import]
|
||||
|
||||
meter = pyln.Meter(sample_rate)
|
||||
stereo = np.stack([audio, audio], axis=1)
|
||||
try:
|
||||
lufs = float(meter.integrated_loudness(stereo))
|
||||
except Exception:
|
||||
lufs = float("nan")
|
||||
|
||||
peak_dbfs = float(20 * np.log10(np.max(np.abs(audio)) + 1e-12))
|
||||
rms_energy = float(np.sqrt(np.mean(audio**2)))
|
||||
|
||||
return AnalysisResult(
|
||||
analyzer_name=self.name,
|
||||
fields={
|
||||
"avg_loudness_lufs": lufs if not np.isnan(lufs) else None,
|
||||
"peak_loudness_dbfs": peak_dbfs,
|
||||
"energy": min(rms_energy, 1.0),
|
||||
},
|
||||
)
|
||||
21
worker/src/worker/analyzers/mfcc.py
Normal file
21
worker/src/worker/analyzers/mfcc.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
from worker.analyzers.base import AnalysisResult, AudioAnalyzer
|
||||
|
||||
|
||||
class MFCCAnalyzer(AudioAnalyzer):
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "mfcc"
|
||||
|
||||
def analyze(self, audio: np.ndarray, sample_rate: int) -> AnalysisResult:
|
||||
import librosa
|
||||
|
||||
mfccs = librosa.feature.mfcc(y=audio, sr=sample_rate, n_mfcc=13)
|
||||
mfcc_mean = np.mean(mfccs, axis=1).tolist()
|
||||
return AnalysisResult(
|
||||
analyzer_name=self.name,
|
||||
fields={"mfcc_mean": [float(v) for v in mfcc_mean]},
|
||||
)
|
||||
31
worker/src/worker/analyzers/spectral.py
Normal file
31
worker/src/worker/analyzers/spectral.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
from worker.analyzers.base import AnalysisResult, AudioAnalyzer
|
||||
|
||||
|
||||
class SpectralAnalyzer(AudioAnalyzer):
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "spectral"
|
||||
|
||||
def analyze(self, audio: np.ndarray, sample_rate: int) -> AnalysisResult:
|
||||
import librosa
|
||||
|
||||
centroid = librosa.feature.spectral_centroid(y=audio, sr=sample_rate)
|
||||
mean_centroid = float(np.mean(centroid))
|
||||
|
||||
try:
|
||||
import essentia.standard as es # type: ignore[import]
|
||||
|
||||
dance = es.Danceability()
|
||||
danceability, _ = dance(audio)
|
||||
danceability_val: float | None = float(danceability)
|
||||
except Exception:
|
||||
danceability_val = None
|
||||
|
||||
return AnalysisResult(
|
||||
analyzer_name=self.name,
|
||||
fields={"spectral_centroid": mean_centroid, "danceability": danceability_val},
|
||||
)
|
||||
25
worker/src/worker/config.py
Normal file
25
worker/src/worker/config.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""Worker configuration loaded from the environment (and optional .env file)."""

from functools import lru_cache
from pydantic_settings import BaseSettings, SettingsConfigDict


class WorkerSettings(BaseSettings):
    """Environment-driven settings; unknown env vars are ignored."""

    model_config = SettingsConfigDict(env_file=".env", extra="ignore")

    # Async SQLAlchemy DSN; default targets a local dev database.
    database_url: str = "postgresql+asyncpg://rh_user:change_me@localhost:5432/rehearsalhub"
    redis_url: str = "redis://localhost:6379/0"
    # Redis list key the worker BLPOPs job ids from (see main.py).
    job_queue_key: str = "rh:jobs"

    # Nextcloud WebDAV endpoint + service credentials used to download audio.
    nextcloud_url: str = "http://nextcloud"
    nextcloud_user: str = "ncadmin"
    nextcloud_pass: str = ""

    # Scratch directory for downloaded/transcoded files.
    audio_tmp_dir: str = "/tmp/audio"
    # Stamped onto each analysis row so results can be recomputed when the
    # pipeline changes.
    analysis_version: str = "1.0.0"

    # Sample rate all analyzers operate on
    target_sample_rate: int = 44100


@lru_cache
def get_settings() -> WorkerSettings:
    """Return the process-wide settings instance (cached singleton)."""
    return WorkerSettings()  # type: ignore[call-arg]
|
||||
72
worker/src/worker/db.py
Normal file
72
worker/src/worker/db.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""Minimal SQLAlchemy models for the worker (mirrors the API models)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import BigInteger, Boolean, DateTime, ForeignKey, Integer, Numeric, String, Text, func
|
||||
from sqlalchemy.dialects.postgresql import ARRAY, JSONB, UUID
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
class AudioVersionModel(Base):
|
||||
__tablename__ = "audio_versions"
|
||||
|
||||
id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True)
|
||||
song_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True))
|
||||
version_number: Mapped[int] = mapped_column(Integer)
|
||||
label: Mapped[Optional[str]] = mapped_column(String(255))
|
||||
nc_file_path: Mapped[str] = mapped_column(Text)
|
||||
nc_file_etag: Mapped[Optional[str]] = mapped_column(String(255))
|
||||
cdn_hls_base: Mapped[Optional[str]] = mapped_column(Text)
|
||||
waveform_url: Mapped[Optional[str]] = mapped_column(Text)
|
||||
duration_ms: Mapped[Optional[int]] = mapped_column(Integer)
|
||||
format: Mapped[Optional[str]] = mapped_column(String(10))
|
||||
file_size_bytes: Mapped[Optional[int]] = mapped_column(BigInteger)
|
||||
analysis_status: Mapped[str] = mapped_column(String(20), default="pending")
|
||||
uploaded_by: Mapped[Optional[uuid.UUID]] = mapped_column(UUID(as_uuid=True))
|
||||
uploaded_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
|
||||
class RangeAnalysisModel(Base):
    """Computed metrics for one annotated time range of an audio version.

    One row per annotation (annotation_id is unique); the range pipeline
    upserts into this table.
    """

    __tablename__ = "range_analyses"

    id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    annotation_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), unique=True)
    version_id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True))
    # Analyzed range within the audio, in milliseconds.
    start_ms: Mapped[int] = mapped_column(Integer)
    end_ms: Mapped[int] = mapped_column(Integer)
    bpm: Mapped[Optional[float]] = mapped_column(Numeric(7, 2))
    bpm_confidence: Mapped[Optional[float]] = mapped_column(Numeric(4, 3))
    key: Mapped[Optional[str]] = mapped_column(String(30))
    key_confidence: Mapped[Optional[float]] = mapped_column(Numeric(4, 3))
    scale: Mapped[Optional[str]] = mapped_column(String(10))
    avg_loudness_lufs: Mapped[Optional[float]] = mapped_column(Numeric(6, 2))
    peak_loudness_dbfs: Mapped[Optional[float]] = mapped_column(Numeric(6, 2))
    spectral_centroid: Mapped[Optional[float]] = mapped_column(Numeric(10, 2))
    energy: Mapped[Optional[float]] = mapped_column(Numeric(5, 4))
    danceability: Mapped[Optional[float]] = mapped_column(Numeric(5, 4))
    # Per-bin / per-coefficient mean vectors from the chroma and MFCC analyzers.
    chroma_vector: Mapped[Optional[list[float]]] = mapped_column(ARRAY(Numeric))
    mfcc_mean: Mapped[Optional[list[float]]] = mapped_column(ARRAY(Numeric))
    # Pipeline version that produced this row (WorkerSettings.analysis_version).
    analysis_version: Mapped[Optional[str]] = mapped_column(String(20))
    computed_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
|
||||
class JobModel(Base):
    """Background job row: ids are pushed onto the Redis queue, state lives here.

    Status transitions driven by main.py: queued -> running -> done | failed.
    """

    __tablename__ = "jobs"

    id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True)
    # Dispatch key, e.g. "transcode" or "analyse_range" (see main.HANDLERS).
    type: Mapped[str] = mapped_column(String(50))
    payload: Mapped[dict] = mapped_column(JSONB)
    status: Mapped[str] = mapped_column(String(20), default="queued")
    # Incremented each time a worker picks the job up.
    attempt: Mapped[int] = mapped_column(Integer, default=0)
    # Error text for failed jobs (truncated by the worker before writing).
    error: Mapped[Optional[str]] = mapped_column(Text)
    queued_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now())
    started_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True))
    finished_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True))
||||
167
worker/src/worker/main.py
Normal file
167
worker/src/worker/main.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""Audio worker entry point. Consumes jobs from Redis queue."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import redis.asyncio as aioredis
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from worker.config import get_settings
|
||||
from worker.db import AudioVersionModel, JobModel
|
||||
from worker.pipeline.analyse_full import run_full_analysis
|
||||
from worker.pipeline.analyse_range import run_range_analysis
|
||||
from worker.pipeline.transcode import get_duration_ms, transcode_to_hls
|
||||
from worker.pipeline.waveform import generate_waveform_file
|
||||
|
||||
# Process-wide logging config; `log` is the module logger used below.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s %(message)s")
log = logging.getLogger("worker")
|
||||
|
||||
|
||||
async def load_audio(nc_path: str, tmp_dir: str, settings) -> tuple[np.ndarray, int, str]:
    """Download a file from Nextcloud and decode it to a mono numpy array.

    Args:
        nc_path: WebDAV-relative path of the file in Nextcloud.
        tmp_dir: Local scratch directory the download is written to.
        settings: WorkerSettings (nextcloud_* credentials, target_sample_rate).

    Returns:
        (audio, sample_rate, local_path) — audio is mono, resampled to
        settings.target_sample_rate; local_path is the downloaded file on disk.

    Raises:
        httpx.HTTPStatusError: when the WebDAV download fails.
    """
    import httpx

    local_path = os.path.join(tmp_dir, Path(nc_path).name)
    dav_url = f"{settings.nextcloud_url}/remote.php/dav/files/{settings.nextcloud_user}/{nc_path.lstrip('/')}"
    async with httpx.AsyncClient(
        auth=(settings.nextcloud_user, settings.nextcloud_pass), timeout=120.0
    ) as client:
        resp = await client.get(dav_url)
        resp.raise_for_status()
        with open(local_path, "wb") as f:
            f.write(resp.content)

    # librosa.load is blocking; decode in the default thread pool.
    # get_running_loop() replaces the deprecated get_event_loop() — this
    # function is always awaited from a coroutine.
    loop = asyncio.get_running_loop()
    audio, sr = await loop.run_in_executor(
        None, lambda: librosa.load(local_path, sr=settings.target_sample_rate, mono=True)
    )
    return audio, sr, local_path
|
||||
|
||||
|
||||
async def handle_transcode(payload: dict, session: AsyncSession, settings) -> None:
    """Job handler: download a version's audio, build HLS + waveform, run full analysis.

    payload must contain "version_id" (UUID string) and "nc_file_path".
    Side effects: updates the audio_versions row (status "running") and
    commits; run_full_analysis commits again with the analysis results.
    """
    version_id = uuid.UUID(payload["version_id"])
    nc_path = payload["nc_file_path"]

    # All scratch files live in one temp dir and are removed on exit.
    with tempfile.TemporaryDirectory(dir=settings.audio_tmp_dir) as tmp:
        audio, sr, local_path = await load_audio(nc_path, tmp, settings)
        duration_ms = await get_duration_ms(local_path)

        hls_dir = os.path.join(tmp, "hls")
        await transcode_to_hls(local_path, hls_dir)

        waveform_path = os.path.join(tmp, "waveform.json")
        await generate_waveform_file(audio, waveform_path)

        # TODO: Upload HLS segments and waveform back to Nextcloud / object storage
        # For now, store the local tmp path in the DB (replace with real upload logic)
        hls_nc_path = f"hls/{version_id}"
        waveform_nc_path = f"waveforms/{version_id}.json"

        # Status "running": run_full_analysis below flips it to "done".
        stmt = (
            update(AudioVersionModel)
            .where(AudioVersionModel.id == version_id)
            .values(
                cdn_hls_base=hls_nc_path,
                waveform_url=waveform_nc_path,
                duration_ms=duration_ms,
                analysis_status="running",
            )
        )
        await session.execute(stmt)
        await session.commit()

        await run_full_analysis(audio, sr, version_id, session)

    log.info("Transcode complete for version %s", version_id)
|
||||
|
||||
|
||||
async def handle_analyse_range(payload: dict, session: AsyncSession, settings) -> None:
    """Job handler: re-download a version's audio and analyze one annotated range.

    payload must contain "version_id" and "annotation_id" (UUID strings) plus
    "start_ms"/"end_ms". Raises ValueError when the version row is missing;
    run_range_analysis writes/updates the range_analyses row and commits.
    """
    version_id = uuid.UUID(payload["version_id"])
    annotation_id = uuid.UUID(payload["annotation_id"])
    start_ms = payload["start_ms"]
    end_ms = payload["end_ms"]

    version = await session.get(AudioVersionModel, version_id)
    if version is None:
        raise ValueError(f"AudioVersion {version_id} not found")

    # Full file is re-downloaded; slicing to the range happens in the pipeline.
    with tempfile.TemporaryDirectory(dir=settings.audio_tmp_dir) as tmp:
        audio, sr, _ = await load_audio(version.nc_file_path, tmp, settings)
        await run_range_analysis(audio, sr, version_id, annotation_id, start_ms, end_ms, session)

    log.info("Range analysis complete for annotation %s", annotation_id)
|
||||
|
||||
|
||||
# Job-type dispatch table; main() looks handlers up by JobModel.type.
HANDLERS = {
    "transcode": handle_transcode,
    "analyse_range": handle_analyse_range,
}
|
||||
|
||||
|
||||
async def main() -> None:
    """Worker main loop: BLPOP job ids from Redis, load the row, dispatch by type.

    Job lifecycle written back to the jobs table:
    queued -> running -> done | failed (error text truncated to 2000 chars).
    Unknown queue entries / missing DB rows are logged and skipped; loop-level
    errors back off for one second and the loop continues forever.
    """
    settings = get_settings()
    os.makedirs(settings.audio_tmp_dir, exist_ok=True)

    engine = create_async_engine(settings.database_url, pool_pre_ping=True)
    session_factory = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
    redis = aioredis.from_url(settings.redis_url, decode_responses=True)

    log.info("Worker started. Listening for jobs on %s", settings.job_queue_key)

    while True:
        try:
            # Blocking pop with a timeout so the loop stays responsive.
            result = await redis.blpop(settings.job_queue_key, timeout=5)
            if result is None:
                continue
            _, raw_id = result
            job_id = uuid.UUID(raw_id)

            # One fresh session per job.
            async with session_factory() as session:
                job = await session.get(JobModel, job_id)
                if job is None:
                    log.warning("Job %s not found in DB", job_id)
                    continue

                # Claim the job before doing any work.
                job.status = "running"
                job.started_at = datetime.now(timezone.utc)
                job.attempt = (job.attempt or 0) + 1
                await session.commit()

                handler = HANDLERS.get(job.type)
                if handler is None:
                    log.error("Unknown job type: %s", job.type)
                    job.status = "failed"
                    job.error = f"Unknown job type: {job.type}"
                    job.finished_at = datetime.now(timezone.utc)
                    await session.commit()
                    continue

                try:
                    await handler(job.payload, session, settings)
                    job.status = "done"
                    job.finished_at = datetime.now(timezone.utc)
                    await session.commit()
                except Exception as exc:
                    # Handler failure: record the error on the job row.
                    log.exception("Job %s failed: %s", job_id, exc)
                    job.status = "failed"
                    job.error = str(exc)[:2000]
                    job.finished_at = datetime.now(timezone.utc)
                    await session.commit()

        except Exception as exc:
            # Catch-all boundary: keep the worker alive on transient errors
            # (Redis/DB hiccups); logged with traceback, then back off briefly.
            log.exception("Worker loop error: %s", exc)
            await asyncio.sleep(1)
|
||||
|
||||
|
||||
# Script entry point: run the worker loop until the process is stopped.
if __name__ == "__main__":
    asyncio.run(main())
|
||||
0
worker/src/worker/pipeline/__init__.py
Normal file
0
worker/src/worker/pipeline/__init__.py
Normal file
46
worker/src/worker/pipeline/analyse_full.py
Normal file
46
worker/src/worker/pipeline/analyse_full.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Full-track analysis: runs BPM + Key analyzers, updates audio_versions table."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from worker.analyzers.bpm import BPMAnalyzer
|
||||
from worker.analyzers.key import KeyAnalyzer
|
||||
|
||||
|
||||
async def run_full_analysis(
|
||||
audio: np.ndarray,
|
||||
sample_rate: int,
|
||||
version_id: uuid.UUID,
|
||||
session: AsyncSession,
|
||||
) -> dict[str, Any]:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
bpm_result = await loop.run_in_executor(None, BPMAnalyzer().analyze, audio, sample_rate)
|
||||
key_result = await loop.run_in_executor(None, KeyAnalyzer().analyze, audio, sample_rate)
|
||||
|
||||
fields: dict[str, Any] = {**bpm_result.fields, **key_result.fields}
|
||||
|
||||
from sqlalchemy import update
|
||||
from worker.db import AudioVersionModel
|
||||
|
||||
global_bpm = fields.get("bpm")
|
||||
global_key = fields.get("key")
|
||||
|
||||
stmt = (
|
||||
update(AudioVersionModel)
|
||||
.where(AudioVersionModel.id == version_id)
|
||||
.values(
|
||||
analysis_status="done",
|
||||
**({} if global_bpm is None else {"global_bpm": global_bpm}),
|
||||
)
|
||||
)
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
return fields
|
||||
93
worker/src/worker/pipeline/analyse_range.py
Normal file
93
worker/src/worker/pipeline/analyse_range.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""Range analysis: slice audio to ms range, run all analyzers, write range_analyses row."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from worker.analyzers import REGISTRY
|
||||
from worker.config import get_settings
|
||||
|
||||
|
||||
def slice_audio(audio: np.ndarray, sample_rate: int, start_ms: int, end_ms: int) -> np.ndarray:
    """Return the portion of `audio` covering [start_ms, end_ms), as a slice."""
    first = int(start_ms * sample_rate / 1000)
    last = int(end_ms * sample_rate / 1000)
    return audio[first:last]
|
||||
|
||||
|
||||
async def run_range_analysis(
    audio: np.ndarray,
    sample_rate: int,
    version_id: uuid.UUID,
    annotation_id: uuid.UUID,
    start_ms: int,
    end_ms: int,
    session: AsyncSession,
) -> dict[str, Any]:
    """
    Slice audio to the range, run all registered analyzers in thread pool,
    merge results, and write to range_analyses table.

    Raises ValueError when the range contains no samples. Analyzer failures
    are tolerated: a failing analyzer simply contributes no fields. Upserts
    on annotation_id (one analysis row per annotation) and commits.
    """
    sliced = slice_audio(audio, sample_rate, start_ms, end_ms)
    if len(sliced) == 0:
        raise ValueError(f"Empty audio slice for range {start_ms}–{end_ms}ms")

    # Always awaited from a coroutine; get_running_loop() replaces the
    # deprecated get_event_loop().
    loop = asyncio.get_running_loop()

    # Analyzers are synchronous; fan out to the default thread pool.
    tasks = [
        loop.run_in_executor(None, analyzer.analyze, sliced, sample_rate)
        for analyzer in REGISTRY
    ]
    results = await asyncio.gather(*tasks, return_exceptions=True)

    merged: dict[str, Any] = {}
    for result in results:
        if isinstance(result, Exception):
            # Best-effort: skip analyzers that failed on this slice.
            continue
        merged.update(result.fields)

    settings = get_settings()
    row_data = {
        "version_id": version_id,
        "annotation_id": annotation_id,
        "start_ms": start_ms,
        "end_ms": end_ms,
        "bpm": merged.get("bpm"),
        "bpm_confidence": merged.get("bpm_confidence"),
        "key": merged.get("key"),
        "key_confidence": merged.get("key_confidence"),
        "scale": merged.get("scale"),
        "avg_loudness_lufs": merged.get("avg_loudness_lufs"),
        "peak_loudness_dbfs": merged.get("peak_loudness_dbfs"),
        "spectral_centroid": merged.get("spectral_centroid"),
        "energy": merged.get("energy"),
        "danceability": merged.get("danceability"),
        "chroma_vector": merged.get("chroma_vector"),
        "mfcc_mean": merged.get("mfcc_mean"),
        "analysis_version": settings.analysis_version,
        "computed_at": datetime.now(timezone.utc),
    }

    # Deferred imports, matching the module's lazy-import style; this also
    # replaces the previous `__import__("sqlalchemy").select(...)` hack with
    # a normal import.
    from sqlalchemy import select

    from worker.db import RangeAnalysisModel

    existing = await session.execute(
        select(RangeAnalysisModel).where(RangeAnalysisModel.annotation_id == annotation_id)
    )
    existing_row = existing.scalar_one_or_none()

    if existing_row is not None:
        # Update in place to respect the unique annotation_id constraint.
        for column, value in row_data.items():
            setattr(existing_row, column, value)
    else:
        session.add(RangeAnalysisModel(**row_data))

    await session.commit()
    return row_data
|
||||
60
worker/src/worker/pipeline/transcode.py
Normal file
60
worker/src/worker/pipeline/transcode.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""FFmpeg transcoding: audio file → HLS segments + waveform peaks JSON."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
async def transcode_to_hls(input_path: str, output_dir: str) -> str:
    """
    Transcode input audio to AAC HLS segments (4s, 192k stereo @ 44.1kHz).
    Returns the path to the generated playlist.m3u8.
    Raises RuntimeError (from _run_ffmpeg) when ffmpeg exits non-zero.
    """
    os.makedirs(output_dir, exist_ok=True)
    playlist = os.path.join(output_dir, "playlist.m3u8")
    cmd = [
        # -y (overwrite) must come before the output file: ffmpeg applies
        # options to the file that follows them, and options trailing the
        # last output are ignored or rejected. Previously it sat after
        # `playlist`, where it had no effect.
        "ffmpeg", "-y", "-i", input_path,
        "-c:a", "aac",
        "-b:a", "192k",
        "-ac", "2",
        "-ar", "44100",
        "-hls_time", "4",
        "-hls_playlist_type", "vod",
        "-hls_segment_filename", os.path.join(output_dir, "seg%03d.aac"),
        playlist,
    ]
    await _run_ffmpeg(cmd)
    return playlist
|
||||
|
||||
|
||||
async def get_duration_ms(input_path: str) -> int:
    """Return the duration of the audio file in milliseconds.

    Returns 0 when ffprobe reports no duration field. Raises RuntimeError
    when ffprobe exits non-zero (previously a failed probe left stdout empty
    and surfaced as an opaque JSONDecodeError).
    """
    cmd = [
        "ffprobe", "-v", "quiet",
        "-print_format", "json",
        "-show_format",
        input_path,
    ]
    proc = await asyncio.create_subprocess_exec(
        *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
    )
    stdout, stderr = await proc.communicate()
    if proc.returncode != 0:
        raise RuntimeError(f"ffprobe failed for {input_path}: {stderr.decode()[:500]}")
    info = json.loads(stdout)
    duration_s = float(info.get("format", {}).get("duration", 0))
    return int(duration_s * 1000)
|
||||
|
||||
|
||||
async def _run_ffmpeg(cmd: list[str]) -> None:
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*cmd, stdout=asyncio.subprocess.DEVNULL, stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
_, stderr = await proc.communicate()
|
||||
if proc.returncode != 0:
|
||||
raise RuntimeError(f"FFmpeg failed: {stderr.decode()[:500]}")
|
||||
46
worker/src/worker/pipeline/waveform.py
Normal file
46
worker/src/worker/pipeline/waveform.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Waveform peak extraction: audio file → JSON peaks for WaveSurfer.js rendering."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def extract_peaks(audio: np.ndarray, num_points: int = 1000) -> list[float]:
    """
    Downsample audio to `num_points` peak magnitudes for waveform display.
    The result is normalized so the largest peak is 1.0 (all-zero or empty
    input yields all zeros).
    """
    if len(audio) == 0:
        return [0.0] * num_points

    step = max(1, len(audio) // num_points)
    magnitudes = np.abs(audio)
    peaks: list[float] = []
    for idx in range(num_points):
        window = magnitudes[idx * step:(idx + 1) * step]
        # Windows past the end of short inputs are empty -> zero padding.
        peaks.append(float(window.max()) if window.size else 0.0)

    scale = max(peaks) or 1.0
    return [value / scale for value in peaks]
|
||||
|
||||
|
||||
def peaks_to_json(peaks: list[float]) -> str:
    """Serialize peaks into the audiowaveform-style JSON the frontend expects."""
    payload = {
        "version": 2,
        "channels": 1,
        "sample_rate": 44100,
        "samples_per_pixel": 256,
        "bits": 8,
        "length": len(peaks),
        "data": peaks,
    }
    return json.dumps(payload)
|
||||
|
||||
|
||||
async def generate_waveform_file(audio: np.ndarray, output_path: str) -> None:
    """Write waveform JSON for *audio* to *output_path*.

    Peak extraction is CPU-bound, so it runs in the default executor to
    avoid blocking the event loop.
    """
    import asyncio

    # get_running_loop() is the supported API inside a coroutine;
    # get_event_loop() is deprecated in this context since Python 3.10.
    peaks = await asyncio.get_running_loop().run_in_executor(
        None, extract_peaks, audio
    )
    content = peaks_to_json(peaks)
    with open(output_path, "w") as f:
        f.write(content)
|
||||
0
worker/src/worker/queue/__init__.py
Normal file
0
worker/src/worker/queue/__init__.py
Normal file
0
worker/tests/__init__.py
Normal file
0
worker/tests/__init__.py
Normal file
22
worker/tests/conftest.py
Normal file
22
worker/tests/conftest.py
Normal file
@@ -0,0 +1,22 @@
|
||||
"""Worker test fixtures."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
def sine_440hz():
    """A 5-second 440Hz sine wave at 44100 Hz — usable as mock audio input."""
    sample_rate = 44100
    duration_s = 5
    times = np.linspace(0, 5.0, sample_rate * duration_s, endpoint=False)
    wave = np.sin(2 * np.pi * 440 * times).astype(np.float32)
    return wave, sample_rate
|
||||
|
||||
|
||||
@pytest.fixture
def short_audio():
    """1 second of white noise."""
    sample_rate = 44100
    generator = np.random.default_rng(42)  # fixed seed keeps tests deterministic
    noise = generator.uniform(-0.5, 0.5, sample_rate).astype(np.float32)
    return noise, sample_rate
|
||||
137
worker/tests/test_analyse_range.py
Normal file
137
worker/tests/test_analyse_range.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""Tests for range analysis pipeline — Essentia mocked out."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from worker.analyzers.base import AnalysisResult
|
||||
from worker.pipeline.analyse_range import run_range_analysis, slice_audio
|
||||
|
||||
|
||||
def test_slice_audio_correct_samples():
    """A 2s–5s slice of a 10s signal should contain ~3 seconds of samples."""
    sample_rate = 44100
    ten_seconds = np.ones(sample_rate * 10, dtype=np.float32)
    result = slice_audio(ten_seconds, sample_rate, start_ms=2000, end_ms=5000)
    expected_len = int(3.0 * sample_rate)
    tolerance = sample_rate * 0.01  # allow up to 10 ms of rounding slack
    assert abs(len(result) - expected_len) <= tolerance
|
||||
|
||||
|
||||
def test_slice_audio_preserves_content():
    """Slicing from t=0 keeps the original leading samples untouched."""
    sample_rate = 44100
    ramp = np.arange(sample_rate * 10, dtype=np.float32)
    first_second = slice_audio(ramp, sample_rate, start_ms=0, end_ms=1000)
    assert first_second[0] == 0.0
    assert len(first_second) == sample_rate
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_range_analysis_merges_results():
    """All analyzers are mocked — tests the merge + DB write logic.

    Each analyzer's `analyze` is patched to return a canned AnalysisResult;
    the assertions check that run_range_analysis flattens the per-analyzer
    fields into one result dict and inserts a new DB row.
    """
    # Imported inside the test so the patches below target the same classes
    # the pipeline resolves at call time.
    from worker.analyzers.bpm import BPMAnalyzer
    from worker.analyzers.key import KeyAnalyzer
    from worker.analyzers.loudness import LoudnessAnalyzer

    # 5 s of audio; content is irrelevant since analyzers are mocked.
    audio = np.sin(np.linspace(0, 2 * np.pi * 5, 44100 * 5)).astype(np.float32)
    sr = 44100
    version_id = uuid.uuid4()
    annotation_id = uuid.uuid4()

    # Async SQLAlchemy session stand-in; add() is sync on real sessions,
    # hence the MagicMock.
    mock_session = AsyncMock()
    mock_session.get = AsyncMock(return_value=None)
    mock_session.add = MagicMock()
    mock_session.commit = AsyncMock()

    # scalar_one_or_none() -> None simulates "no existing analysis row",
    # which should drive the pipeline down the INSERT path (session.add).
    result_mock = AsyncMock()
    result_mock.scalar_one_or_none.return_value = None
    mock_session.execute.return_value = result_mock

    with (
        patch.object(
            BPMAnalyzer,
            "analyze",
            return_value=AnalysisResult("bpm", {"bpm": 120.0, "bpm_confidence": 0.9}),
        ),
        patch.object(
            KeyAnalyzer,
            "analyze",
            return_value=AnalysisResult("key", {"key": "A minor", "scale": "minor", "key_confidence": 0.8}),
        ),
        patch.object(
            LoudnessAnalyzer,
            "analyze",
            return_value=AnalysisResult(
                "loudness",
                {"avg_loudness_lufs": -18.0, "peak_loudness_dbfs": -6.0, "energy": 0.5},
            ),
        ),
    ):
        result = await run_range_analysis(
            audio=audio,
            sample_rate=sr,
            version_id=version_id,
            annotation_id=annotation_id,
            start_ms=0,
            end_ms=5000,
            session=mock_session,
        )

    # One field from each mocked analyzer must survive the merge.
    assert result["bpm"] == 120.0
    assert result["key"] == "A minor"
    assert result["avg_loudness_lufs"] == -18.0
    # Exactly one new row added (no pre-existing row was found above).
    mock_session.add.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_range_analysis_handles_analyzer_failure():
    """If one analyzer raises, others should still run.

    BPM is forced to raise; chroma is mocked to succeed. The pipeline must
    swallow the per-analyzer failure and keep the successful fields.
    """
    from worker.analyzers.bpm import BPMAnalyzer
    from worker.analyzers.chroma import ChromaAnalyzer

    # 1 s of constant signal; content is irrelevant since analyzers are mocked.
    audio = np.ones(44100, dtype=np.float32)

    # Session stand-in; scalar_one_or_none() -> None simulates "no existing
    # analysis row" so a new one is created.
    mock_session = AsyncMock()
    result_mock = AsyncMock()
    result_mock.scalar_one_or_none.return_value = None
    mock_session.execute.return_value = result_mock
    mock_session.add = MagicMock()
    mock_session.commit = AsyncMock()

    with (
        patch.object(BPMAnalyzer, "analyze", side_effect=RuntimeError("Essentia not available")),
        patch.object(
            ChromaAnalyzer,
            "analyze",
            return_value=AnalysisResult("chroma", {"chroma_vector": [0.1] * 12}),
        ),
    ):
        result = await run_range_analysis(
            audio=audio,
            sample_rate=44100,
            version_id=uuid.uuid4(),
            annotation_id=uuid.uuid4(),
            start_ms=0,
            end_ms=1000,
            session=mock_session,
        )

    # bpm should be None (failed), chroma should be present
    assert result.get("bpm") is None
    assert result.get("chroma_vector") == [0.1] * 12
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_range_analysis_empty_slice_raises():
    """A zero-length range (start == end on empty audio) must raise ValueError."""
    empty = np.array([], dtype=np.float32)
    with pytest.raises(ValueError, match="Empty audio slice"):
        await run_range_analysis(
            audio=empty,
            sample_rate=44100,
            version_id=uuid.uuid4(),
            annotation_id=uuid.uuid4(),
            start_ms=0,
            end_ms=0,
            session=AsyncMock(),
        )
|
||||
88
worker/tests/test_analyzers.py
Normal file
88
worker/tests/test_analyzers.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""Unit tests for individual analyzers (Essentia mocked, librosa used directly)."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from worker.analyzers.base import AnalysisResult
|
||||
from worker.analyzers.chroma import ChromaAnalyzer
|
||||
from worker.analyzers.loudness import LoudnessAnalyzer
|
||||
from worker.analyzers.mfcc import MFCCAnalyzer
|
||||
from worker.analyzers.spectral import SpectralAnalyzer
|
||||
|
||||
|
||||
def test_loudness_analyzer_returns_expected_fields(sine_440hz):
    """Loudness analysis yields LUFS/peak/energy fields with energy in [0, 1]."""
    audio, sr = sine_440hz
    outcome = LoudnessAnalyzer().analyze(audio, sr)
    assert isinstance(outcome, AnalysisResult)
    assert outcome.analyzer_name == "loudness"
    for field_name in ("avg_loudness_lufs", "peak_loudness_dbfs", "energy"):
        assert field_name in outcome.fields
    energy = outcome.fields["energy"]
    assert energy is not None
    assert 0.0 <= energy <= 1.0
|
||||
|
||||
|
||||
def test_chroma_analyzer_returns_12_dimensions(sine_440hz):
    """Chroma output carries one float per pitch class (12 total)."""
    audio, sr = sine_440hz
    outcome = ChromaAnalyzer().analyze(audio, sr)
    assert outcome.analyzer_name == "chroma"
    vector = outcome.fields["chroma_vector"]
    assert vector is not None
    assert len(vector) == 12
    assert all(isinstance(value, float) for value in vector)
|
||||
|
||||
|
||||
def test_mfcc_analyzer_returns_13_dimensions(sine_440hz):
    """MFCC mean vector has the conventional 13 coefficients."""
    audio, sr = sine_440hz
    outcome = MFCCAnalyzer().analyze(audio, sr)
    assert outcome.analyzer_name == "mfcc"
    coefficients = outcome.fields["mfcc_mean"]
    assert coefficients is not None
    assert len(coefficients) == 13
|
||||
|
||||
|
||||
def test_spectral_analyzer_returns_centroid(sine_440hz):
    """A pure 440 Hz tone should produce a spectral centroid near 440 Hz."""
    audio, sr = sine_440hz
    outcome = SpectralAnalyzer().analyze(audio, sr)
    assert "spectral_centroid" in outcome.fields
    centroid = outcome.fields["spectral_centroid"]
    assert centroid is not None
    # Generous band around 440 Hz to tolerate windowing/leakage effects.
    assert 300 < centroid < 600
|
||||
|
||||
|
||||
def test_bpm_analyzer_falls_back_to_librosa_when_essentia_unavailable(sine_440hz):
    """With Essentia blocked, the librosa fallback still yields a positive BPM."""
    from worker.analyzers.bpm import BPMAnalyzer

    audio, sr = sine_440hz
    # A None entry in sys.modules makes `import essentia` raise ImportError.
    blocked = {"essentia": None, "essentia.standard": None}
    with patch.dict("sys.modules", blocked):
        with patch.object(
            BPMAnalyzer, "_essentia_bpm", side_effect=ImportError("no essentia")
        ):
            outcome = BPMAnalyzer().analyze(audio, sr)

    assert outcome.analyzer_name == "bpm"
    assert "bpm" in outcome.fields
    # Exact tempo of a sine wave is ill-defined; only require a positive value.
    assert outcome.fields["bpm"] is not None
    assert outcome.fields["bpm"] > 0
|
||||
|
||||
|
||||
def test_key_analyzer_returns_none_fields_when_essentia_unavailable(sine_440hz):
    """KeyAnalyzer must degrade gracefully (no crash) when Essentia is missing.

    The original version patched "worker.analyzers.key.__import__", which
    raises AttributeError (modules don't define an __import__ attribute), so
    the test errored instead of testing. Setting the sys.modules entries to
    None is sufficient: `import essentia(.standard)` then raises ImportError.
    The outer `patch.object(KeyAnalyzer, "analyze", wraps=...)` was a no-op
    and is removed.
    """
    audio, sr = sine_440hz
    from worker.analyzers.key import KeyAnalyzer

    with patch.dict("sys.modules", {"essentia": None, "essentia.standard": None}):
        result = KeyAnalyzer().analyze(audio, sr)

    # When Essentia fails, returns None fields (no crash)
    assert result.analyzer_name == "key"
|
||||
49
worker/tests/test_waveform.py
Normal file
49
worker/tests/test_waveform.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""Tests for waveform peak extraction (no Essentia required)."""
|
||||
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from worker.pipeline.waveform import extract_peaks, generate_waveform_file, peaks_to_json
|
||||
|
||||
|
||||
def test_extract_peaks_returns_correct_length(sine_440hz):
    """extract_peaks honors the requested number of points."""
    audio, _sr = sine_440hz
    assert len(extract_peaks(audio, num_points=500)) == 500
||||
|
||||
|
||||
def test_extract_peaks_normalized_between_0_and_1(sine_440hz):
    """Every peak lies in [0, 1] and the loudest normalizes to ~1.0."""
    audio, _sr = sine_440hz
    values = extract_peaks(audio, num_points=200)
    assert all(0.0 <= value <= 1.0 for value in values)
    assert max(values) == pytest.approx(1.0, abs=0.01)
||||
|
||||
|
||||
def test_extract_peaks_empty_audio():
    """Empty input yields a zero-filled waveform of the requested length."""
    silence = np.array([], dtype=np.float32)
    values = extract_peaks(silence, num_points=100)
    assert len(values) == 100
    assert all(value == 0.0 for value in values)
||||
|
||||
|
||||
def test_peaks_to_json_valid_structure(sine_440hz):
    """Serialized waveform JSON round-trips with the expected header fields."""
    audio, _sr = sine_440hz
    values = extract_peaks(audio)
    decoded = json.loads(peaks_to_json(values))
    assert decoded["version"] == 2
    assert decoded["channels"] == 1
    assert len(decoded["data"]) == len(values)
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_generate_waveform_file_writes_json(tmp_path, sine_440hz):
    """The async writer produces parseable JSON with the default 1000 points."""
    audio, _sr = sine_440hz
    destination = str(tmp_path / "waveform.json")
    await generate_waveform_file(audio, destination)
    with open(destination) as handle:
        payload = json.load(handle)
    assert payload["version"] == 2
    assert len(payload["data"]) == 1000
||||
Reference in New Issue
Block a user