55 lines
1.6 KiB
Python
Executable File
55 lines
1.6 KiB
Python
Executable File
"""Generic async repository. All concrete repos extend BaseRepository[T]."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import uuid
|
|
from collections.abc import Sequence
|
|
from typing import Any, Generic, TypeVar
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from rehearsalhub.db.models import Base
|
|
|
|
# Type variable for the model class a repository works with; restricted to
# subclasses of ``Base`` (imported from ``rehearsalhub.db.models``).
ModelT = TypeVar("ModelT", bound=Base)
|
|
|
|
|
|
class BaseRepository(Generic[ModelT]):
    """Generic CRUD helper around an ``AsyncSession`` for one model class.

    Subclasses are expected to assign ``model`` to the mapped class they
    manage. Write methods flush the session but never commit; transaction
    boundaries are left to the caller.
    """

    # Mapped class this repository operates on; declared here, assigned by
    # concrete subclasses.
    model: type[ModelT]

    def __init__(self, session: AsyncSession) -> None:
        """Store the session that every operation of this repository uses."""
        self.session = session

    async def get_by_id(self, id: uuid.UUID) -> ModelT | None:
        """Look up a single row by identifier via ``session.get``.

        Returns the matching instance, or ``None`` when nothing is found.
        """
        found = await self.session.get(self.model, id)
        return found

    async def list(self, **filters: Any) -> Sequence[ModelT]:
        """Return all rows whose columns equal the given keyword filters.

        With no filters, every row of ``model`` is selected.
        """
        query = select(self.model).filter_by(**filters)
        rows = (await self.session.execute(query)).scalars()
        return rows.all()

    async def create(self, **kwargs: Any) -> ModelT:
        """Instantiate ``model`` from ``kwargs``, add it, flush and refresh.

        Returns the new instance after it has been refreshed from the
        session.
        """
        db = self.session
        instance = self.model(**kwargs)
        db.add(instance)
        # Flush first so the row exists before it is reloaded.
        await db.flush()
        await db.refresh(instance)
        return instance

    async def update(self, obj: ModelT, **kwargs: Any) -> ModelT:
        """Assign each keyword argument onto ``obj``, then flush and refresh.

        Returns the same ``obj`` that was passed in.
        """
        for attr_name in kwargs:
            setattr(obj, attr_name, kwargs[attr_name])
        db = self.session
        await db.flush()
        await db.refresh(obj)
        return obj

    async def delete(self, obj: ModelT) -> None:
        """Mark ``obj`` for deletion and flush the session."""
        db = self.session
        await db.delete(obj)
        await db.flush()

    async def count(self, **filters: Any) -> int:
        """Return the number of rows matching the given keyword filters."""
        from sqlalchemy import func, select

        counting = select(func.count()).select_from(self.model)
        counting = counting.filter_by(**filters)
        outcome = await self.session.execute(counting)
        return outcome.scalar_one()
|