"""Redis-backed job queue. Strategy: Postgres is the source of truth (durable audit log + retry counts). Redis holds a list of job UUIDs for fast signaling. Workers pop a UUID, load the full payload from Postgres, process, then update status in Postgres. The Redis push is deferred until AFTER the session commits so the worker never reads a job ID that isn't yet visible in the DB. """ from __future__ import annotations import uuid from datetime import datetime, timezone from typing import Any import redis.asyncio as aioredis from sqlalchemy.ext.asyncio import AsyncSession from rehearsalhub.config import get_settings from rehearsalhub.db.models import Job _PENDING_ATTR = "_pending_redis_pushes" class RedisJobQueue: def __init__(self, session: AsyncSession, redis_client: aioredis.Redis | None = None) -> None: self._session = session self._redis: aioredis.Redis | None = redis_client async def _get_redis(self) -> aioredis.Redis: if self._redis is None: self._redis = aioredis.from_url(get_settings().redis_url, decode_responses=True) return self._redis async def enqueue(self, job_type: str, payload: dict[str, Any]) -> uuid.UUID: job = Job(type=job_type, payload=payload, status="queued") self._session.add(job) await self._session.flush() await self._session.refresh(job) job_id = job.id # Defer the Redis push until get_session commits, so the worker never # reads a job ID that isn't yet visible in the DB. pending: list = getattr(self._session, _PENDING_ATTR, None) # type: ignore[attr-defined] if pending is None: pending = [] setattr(self._session, _PENDING_ATTR, pending) # type: ignore[attr-defined] redis_client = await self._get_redis() queue_key = get_settings().job_queue_key async def _push() -> None: await redis_client.rpush(queue_key, str(job_id)) pending.append(_push) return job_id async def mark_running(self, job_id: uuid.UUID) -> None: job = await self._session.get(Job, job_id) if job: job.status = "running" job.started_at = datetime.now(timezone.utc) job.attempt = (job.attempt or 0) + 1 await self._session.flush() async def mark_done(self, job_id: uuid.UUID) -> None: job = await self._session.get(Job, job_id) if job: job.status = "done" job.finished_at = datetime.now(timezone.utc) await self._session.flush() async def mark_failed(self, job_id: uuid.UUID, error: str) -> None: job = await self._session.get(Job, job_id) if job: job.status = "failed" job.error = error[:2000] job.finished_at = datetime.now(timezone.utc) await self._session.flush() async def dequeue(self, timeout: int = 5) -> tuple[uuid.UUID, str, dict[str, Any]] | None: """Block up to `timeout` seconds for a job. Returns (id, type, payload) or None.""" redis_client = await self._get_redis() queue_key = get_settings().job_queue_key result = await redis_client.blpop(queue_key, timeout=timeout) if result is None: return None _, raw_id = result job_id = uuid.UUID(raw_id) job = await self._session.get(Job, job_id) if job is None: return None return job_id, job.type, job.payload # type: ignore[return-value] async def close(self) -> None: if self._redis: await self._redis.aclose() async def flush_pending_pushes(session: AsyncSession) -> None: """Called by get_session after commit() to fire deferred Redis pushes.""" pending: list | None = getattr(session, _PENDING_ATTR, None) # type: ignore[attr-defined] if not pending: return for push in pending: await push() pending.clear()