"""Generic async repository. All concrete repos extend BaseRepository[T]."""

from __future__ import annotations

import uuid
from typing import Any, Generic, Sequence, TypeVar

from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession

from rehearsalhub.db.models import Base

# ORM model type a concrete repository manages; constrained to ``Base`` subclasses.
ModelT = TypeVar("ModelT", bound=Base)


class BaseRepository(Generic[ModelT]):
    """Common CRUD helpers over a single ORM model, bound to one ``AsyncSession``.

    Concrete subclasses are expected to set :attr:`model` to the mapped class
    they manage (it is declared, not assigned, here).

    None of these methods commit. Writes are flushed so the SQL is emitted
    inside the session's current transaction; committing or rolling back is
    left to whoever owns the session.
    """

    # Mapped class this repository operates on; set by each concrete subclass.
    model: type[ModelT]

    def __init__(self, session: AsyncSession) -> None:
        """Bind the repository to *session*; every query below runs through it."""
        self.session = session

    async def get_by_id(self, id: uuid.UUID) -> ModelT | None:
        """Return the :attr:`model` instance with primary key *id*, or ``None``.

        ``id`` shadows the builtin. The name is kept because callers may pass it
        by keyword.
        """
        return await self.session.get(self.model, id)

    async def list(self, **filters: Any) -> Sequence[ModelT]:
        """Return all :attr:`model` rows matching the equality *filters*.

        Each keyword is applied as ``column == value`` via ``filter_by``. With
        no filters, every row is returned. (The method name shadows the builtin
        ``list`` inside the class body only; it is kept for callers.)
        """
        stmt = select(self.model).filter_by(**filters)
        result = await self.session.execute(stmt)
        return result.scalars().all()

    async def create(self, **kwargs: Any) -> ModelT:
        """Construct a :attr:`model` from *kwargs*, add it, flush, and return it.

        The instance is refreshed after the flush so values assigned by the
        database during the INSERT can be read from the returned object.
        """
        obj = self.model(**kwargs)
        self.session.add(obj)
        await self.session.flush()
        await self.session.refresh(obj)
        return obj

    async def update(self, obj: ModelT, **kwargs: Any) -> ModelT:
        """Set each ``name=value`` in *kwargs* on *obj*, flush, refresh, and return it.

        NOTE(review): keys are not validated against the model's mapped
        columns. An unmapped name is presumably stored as a plain Python
        attribute and never persisted. Confirm whether callers need stricter
        checking.
        """
        for key, value in kwargs.items():
            setattr(obj, key, value)
        await self.session.flush()
        # Reload from the database so any DB-side changes made during the
        # UPDATE are reflected on the returned instance.
        await self.session.refresh(obj)
        return obj

    async def delete(self, obj: ModelT) -> None:
        """Mark *obj* for deletion and flush so the DELETE is emitted now."""
        await self.session.delete(obj)
        await self.session.flush()

    async def count(self, **filters: Any) -> int:
        """Return the number of :attr:`model` rows matching the equality *filters*."""
        # ``select_from`` supplies the entity that ``filter_by`` resolves the
        # keyword names against, since ``func.count()`` itself names no table.
        stmt = select(func.count()).select_from(self.model).filter_by(**filters)
        result = await self.session.execute(stmt)
        return result.scalar_one()