Files
rehearshalhub/api/src/rehearsalhub/repositories/base.py
Mistral Vibe 411414b9c1 Fixing build
2026-04-10 10:23:32 +02:00

55 lines
1.6 KiB
Python
Executable File

"""Generic async repository. All concrete repos extend BaseRepository[T]."""
from __future__ import annotations
import uuid
from collections.abc import Sequence
from typing import Any, Generic, TypeVar
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from rehearsalhub.db.models import Base
# Type variable for the model class a repository operates on; bounded by the
# project's ``Base`` (imported from rehearsalhub.db.models).
ModelT = TypeVar("ModelT", bound=Base)
class BaseRepository(Generic[ModelT]):
    """Reusable async CRUD operations for a single model class.

    Every method reads ``self.model``, which is only declared here, so a
    concrete repository is expected to assign it. Writes are flushed to the
    session but no method in this class issues a commit.
    """

    # Model class handled by the repository; declared without a value here.
    model: type[ModelT]

    def __init__(self, session: AsyncSession) -> None:
        """Keep a reference to the session all operations will run on."""
        self.session = session

    async def _flush_and_refresh(self, obj: ModelT) -> ModelT:
        """Flush the session, refresh ``obj`` from it, and hand ``obj`` back."""
        await self.session.flush()
        # Refresh after the flush -- presumably so database-generated values
        # become visible on the instance (confirm against the models).
        await self.session.refresh(obj)
        return obj

    async def get_by_id(self, id: uuid.UUID) -> ModelT | None:
        """Fetch the instance identified by ``id`` via ``session.get``.

        Returns whatever ``session.get`` yields, which per the annotation is
        the instance or ``None``.
        """
        found = await self.session.get(self.model, id)
        return found

    async def list(self, **filters: Any) -> Sequence[ModelT]:
        """Return all instances matching the keyword ``filters``.

        The filters are forwarded unchanged to ``filter_by``; calling without
        any filters selects every row of ``self.model``.
        """
        rows = await self.session.execute(
            select(self.model).filter_by(**filters)
        )
        matches = rows.scalars()
        return matches.all()

    async def create(self, **kwargs: Any) -> ModelT:
        """Build a new ``self.model`` from ``kwargs`` and add it to the session.

        The instance is flushed and refreshed before being returned.
        """
        instance = self.model(**kwargs)
        self.session.add(instance)
        return await self._flush_and_refresh(instance)

    async def update(self, obj: ModelT, **kwargs: Any) -> ModelT:
        """Assign each keyword argument as an attribute on ``obj``.

        Attribute names are not validated; each pair goes straight through
        ``setattr``. The instance is then flushed, refreshed and returned.
        """
        for attr_name, new_value in kwargs.items():
            setattr(obj, attr_name, new_value)
        return await self._flush_and_refresh(obj)

    async def delete(self, obj: ModelT) -> None:
        """Mark ``obj`` for deletion in the session and flush."""
        session = self.session
        await session.delete(obj)
        await session.flush()

    async def count(self, **filters: Any) -> int:
        """Return the number of rows of ``self.model`` matching ``filters``."""
        # ``select`` comes from the module-level import; only ``func`` is
        # pulled in locally.
        from sqlalchemy import func

        counting = (
            select(func.count())
            .select_from(self.model)
            .filter_by(**filters)
        )
        outcome = await self.session.execute(counting)
        return outcome.scalar_one()