diff --git a/api/src/rehearsalhub/db/engine.py b/api/src/rehearsalhub/db/engine.py
index 89e57d0..bcfe520 100644
--- a/api/src/rehearsalhub/db/engine.py
+++ b/api/src/rehearsalhub/db/engine.py
@@ -39,10 +39,15 @@ def get_session_factory() -> async_sessionmaker[AsyncSession]:
 
 async def get_session() -> AsyncGenerator[AsyncSession, None]:
     """FastAPI dependency that yields an async DB session."""
+    from rehearsalhub.queue.redis_queue import flush_pending_pushes
+
     async with get_session_factory()() as session:
         try:
             yield session
             await session.commit()
+            # Fire any deferred Redis pushes AFTER commit so the worker always
+            # finds the job row already committed in the DB.
+            await flush_pending_pushes(session)
         except Exception:
             await session.rollback()
             raise
diff --git a/api/src/rehearsalhub/queue/redis_queue.py b/api/src/rehearsalhub/queue/redis_queue.py
index 6378c68..08f676e 100644
--- a/api/src/rehearsalhub/queue/redis_queue.py
+++ b/api/src/rehearsalhub/queue/redis_queue.py
@@ -3,11 +3,13 @@
 Strategy: Postgres is the source of truth (durable audit log + retry counts).
 Redis holds a list of job UUIDs for fast signaling. Workers pop a UUID, load
 the full payload from Postgres, process, then update status in Postgres.
+
+The Redis push is deferred until AFTER the session commits so the worker
+never reads a job ID that isn't yet visible in the DB.
 """
 
 from __future__ import annotations
 
-import json
 import uuid
 from datetime import datetime, timezone
 from typing import Any
@@ -18,6 +20,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
 from rehearsalhub.config import get_settings
 from rehearsalhub.db.models import Job
 
+_PENDING_ATTR = "_pending_redis_pushes"
+
 
 class RedisJobQueue:
     def __init__(self, session: AsyncSession, redis_client: aioredis.Redis | None = None) -> None:
@@ -34,24 +38,23 @@ class RedisJobQueue:
         self._session.add(job)
         await self._session.flush()
         await self._session.refresh(job)
+        job_id = job.id
 
-        r = await self._get_redis()
-        queue_key = get_settings().job_queue_key
-        await r.rpush(queue_key, str(job.id))
-        return job.id
+        # Defer the Redis push until get_session commits, so the worker never
+        # reads a job ID that isn't yet visible in the DB.
+        pending: list = getattr(self._session, _PENDING_ATTR, None)  # type: ignore[attr-defined]
+        if pending is None:
+            pending = []
+            setattr(self._session, _PENDING_ATTR, pending)  # type: ignore[attr-defined]
 
-    async def dequeue(self, timeout: int = 5) -> tuple[uuid.UUID, str, dict[str, Any]] | None:
-        r = await self._get_redis()
+        redis_client = await self._get_redis()
         queue_key = get_settings().job_queue_key
-        result = await r.blpop(queue_key, timeout=timeout)
-        if result is None:
-            return None
-        _, raw_id = result
-        job_id = uuid.UUID(raw_id)
-        job = await self._session.get(Job, job_id)
-        if job is None:
-            return None
-        return job.id, job.type, job.payload
+
+        async def _push() -> None:
+            await redis_client.rpush(queue_key, str(job_id))
+
+        pending.append(_push)
+        return job_id
 
     async def mark_running(self, job_id: uuid.UUID) -> None:
         job = await self._session.get(Job, job_id)
@@ -79,3 +82,13 @@ class RedisJobQueue:
     async def close(self) -> None:
         if self._redis:
             await self._redis.aclose()
+
+
+async def flush_pending_pushes(session: AsyncSession) -> None:
+    """Called by get_session after commit() to fire deferred Redis pushes."""
+    pending: list | None = getattr(session, _PENDING_ATTR, None)  # type: ignore[attr-defined]
+    if not pending:
+        return
+    for push in pending:
+        await push()
+    pending.clear()
diff --git a/api/src/rehearsalhub/routers/songs.py b/api/src/rehearsalhub/routers/songs.py
index f5106b0..8b95bf2 100644
--- a/api/src/rehearsalhub/routers/songs.py
+++ b/api/src/rehearsalhub/routers/songs.py
@@ -99,61 +99,85 @@ async def scan_nextcloud(
         return href.lstrip("/")
 
     imported: list[SongRead] = []
+    band_folder = band.nc_folder_path or f"bands/{band.slug}/"
+
+    log.info("Starting NC scan for band '%s' in folder '%s'", band.slug, band_folder)
 
     try:
-        items = await nc.list_folder(band.nc_folder_path or f"bands/{band.slug}/")
+        items = await nc.list_folder(band_folder)
     except Exception as exc:
         raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=f"Nextcloud unreachable: {exc}")
 
-    # Collect (nc_file_path, song_folder_rel, song_title) tuples
-    to_import: list[tuple[str, str, str]] = []
+    log.info("Found %d top-level entries in '%s'", len(items), band_folder)
+
+    # Collect (nc_file_path, nc_folder, song_title, rehearsal_label) tuples.
+    # nc_folder is the directory that groups versions of the same song.
+    # For YYMMDD / dated rehearsal subfolders each file is its own song —
+    # the song title comes from the filename stem, not the folder name.
+    to_import: list[tuple[str, str, str, str | None]] = []
 
     for item in items:
         rel = relative(item.path)
         if rel.endswith("/"):
-            # It's a subdirectory — scan one level deeper
+            dir_name = Path(rel.rstrip("/")).name
             try:
                 sub_items = await nc.list_folder(rel)
-            except Exception:
+            except Exception as exc:
+                log.warning("Could not list subfolder '%s': %s", rel, exc)
                 continue
-            dir_name = Path(rel.rstrip("/")).name
-            for sub in sub_items:
+
+            audio_files = [s for s in sub_items if Path(relative(s.path)).suffix.lower() in AUDIO_EXTENSIONS]
+            log.info("Subfolder '%s': %d audio files found", dir_name, len(audio_files))
+
+            for sub in audio_files:
                 sub_rel = relative(sub.path)
-                if Path(sub_rel).suffix.lower() in AUDIO_EXTENSIONS:
-                    to_import.append((sub_rel, rel, dir_name))
+                song_title = Path(sub_rel).stem
+                # Intent: each file in a rehearsal folder is its own song. NOTE(review): song_folder
+                # is the same for every file here, so the nc_folder lookup below may merge them - confirm.
+                song_folder = str(Path(sub_rel).parent) + "/"
+                rehearsal_label = dir_name  # e.g. "231015" or "2023-10-15"
+                to_import.append((sub_rel, song_folder, song_title, rehearsal_label))
         else:
             if Path(rel).suffix.lower() in AUDIO_EXTENSIONS:
                 folder = str(Path(rel).parent) + "/"
                 title = Path(rel).stem
-                to_import.append((rel, folder, title))
+                to_import.append((rel, folder, title, None))
 
-    for nc_file_path, nc_folder, song_title in to_import:
-        # Skip if version already registered by etag
+    log.info("NC scan: %d audio files to evaluate for import", len(to_import))
+
+    song_repo = SongRepository(session)
+    from rehearsalhub.schemas.audio_version import AudioVersionCreate  # noqa: PLC0415
+
+    for nc_file_path, nc_folder, song_title, rehearsal_label in to_import:
+        # Skip if this exact file version is already registered
         try:
             meta = await nc.get_file_metadata(nc_file_path)
             etag = meta.etag
-        except Exception:
-            etag = None
-
-        if etag and await version_repo.get_by_etag(etag):
+        except Exception as exc:
+            log.warning("Could not fetch metadata for '%s': %s — skipping", nc_file_path, exc)
             continue
 
-        # Find or create song
-        song_repo = SongRepository(session)
+        if etag and await version_repo.get_by_etag(etag):
+            log.debug("Skipping '%s' — etag already registered", nc_file_path)
+            continue
+
+        # Find or create song record
         song = await song_repo.get_by_nc_folder_path(nc_folder)
         if song is None:
             song = await song_repo.get_by_title_and_band(band_id, song_title)
         if song is None:
+            log.info("Creating new song '%s' (folder: %s)", song_title, nc_folder)
             song = await song_repo.create(
                 band_id=band_id,
                 title=song_title,
                 status="jam",
-                notes=None,
+                notes=f"Rehearsal: {rehearsal_label}" if rehearsal_label else None,
                 nc_folder_path=nc_folder,
                 created_by=current_member.id,
             )
+        else:
+            log.info("Found existing song '%s' (id: %s)", song.title, song.id)
 
-        from rehearsalhub.schemas.audio_version import AudioVersionCreate  # noqa: PLC0415
         await song_svc.register_version(
             song.id,
             AudioVersionCreate(
@@ -168,8 +192,10 @@ async def scan_nextcloud(
         read = SongRead.model_validate(song)
         read.version_count = 1
         imported.append(read)
-        log.info("Imported %s as song '%s'", nc_file_path, song_title)
+        label_info = f" [rehearsal: {rehearsal_label}]" if rehearsal_label else ""
+        log.info("Imported '%s' as song '%s'%s", nc_file_path, song_title, label_info)
 
+    log.info("NC scan complete: %d new versions imported", len(imported))
     return imported
 
 
diff --git a/worker/src/worker/main.py b/worker/src/worker/main.py
index 5a694cd..2c72db7 100644
--- a/worker/src/worker/main.py
+++ b/worker/src/worker/main.py
@@ -116,6 +116,13 @@ async def main() -> None:
     session_factory = async_sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
     redis = aioredis.from_url(settings.redis_url, decode_responses=True)
 
+    # Drain job IDs left in Redis by previous runs whose API transactions were never committed.
+    # NOTE(review): DEL also drops IDs of committed jobs queued while the worker was down - confirm.
+    stale = await redis.llen(settings.job_queue_key)
+    if stale:
+        log.warning("Draining %d stale job IDs from Redis queue before starting", stale)
+        await redis.delete(settings.job_queue_key)
+
     log.info("Worker started. Listening for jobs on %s", settings.job_queue_key)
 
     while True:
@@ -125,11 +132,21 @@ async def main() -> None:
                 continue
             _, raw_id = result
             job_id = uuid.UUID(raw_id)
+            log.info("Dequeued job %s", job_id)
 
             async with session_factory() as session:
-                job = await session.get(JobModel, job_id)
+                # Brief retry: the deferred Redis push fires right after API commit,
+                # so a tiny propagation delay is still possible.
+                job = None
+                for _attempt in range(3):
+                    job = await session.get(JobModel, job_id)
+                    if job is not None:
+                        break
+                    await asyncio.sleep(0.2)
+                    session.expire_all()  # synchronous on AsyncSession; awaiting its None result raises TypeError
+
                 if job is None:
-                    log.warning("Job %s not found in DB", job_id)
+                    log.warning("Job %s not found in DB after retries — discarding", job_id)
                     continue
 
                 job.status = "running"
@@ -139,18 +156,20 @@ async def main() -> None:
 
                 handler = HANDLERS.get(job.type)
                 if handler is None:
-                    log.error("Unknown job type: %s", job.type)
+                    log.error("Job %s has unknown type '%s' — marking failed", job_id, job.type)
                     job.status = "failed"
                     job.error = f"Unknown job type: {job.type}"
                     job.finished_at = datetime.now(timezone.utc)
                     await session.commit()
                     continue
 
+                log.info("Running job %s type=%s payload=%s", job_id, job.type, job.payload)
                 try:
                     await handler(job.payload, session, settings)
                     job.status = "done"
                     job.finished_at = datetime.now(timezone.utc)
                     await session.commit()
+                    log.info("Job %s done", job_id)
                 except Exception as exc:
                     log.exception("Job %s failed: %s", job_id, exc)
                     job.status = "failed"