from __future__ import annotations from typing import Generic, Sequence, TypeVar from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession from sqlmodel import SQLModel T = TypeVar("T", bound=SQLModel) class BaseRepository(Generic[T]): def __init__(self, model: type[T], session: AsyncSession): self.model = model self.session = session async def get_by_id(self, id: int) -> T | None: return await self.session.get(self.model, id) async def get_all( self, *, skip: int = 0, limit: int = 100, filters: dict | None = None ) -> Sequence[T]: stmt = select(self.model) if filters: for key, value in filters.items(): if hasattr(self.model, key): stmt = stmt.where(getattr(self.model, key) == value) stmt = stmt.offset(skip).limit(limit) result = await self.session.execute(stmt) return result.scalars().all() async def count(self, filters: dict | None = None) -> int: stmt = select(func.count()).select_from(self.model) if filters: for key, value in filters.items(): if hasattr(self.model, key): stmt = stmt.where(getattr(self.model, key) == value) result = await self.session.execute(stmt) return result.scalar_one() async def create(self, obj: T) -> T: self.session.add(obj) await self.session.flush() await self.session.refresh(obj) return obj async def update(self, obj: T, data: dict) -> T: for key, value in data.items(): if value is not None and hasattr(obj, key): setattr(obj, key, value) self.session.add(obj) await self.session.flush() await self.session.refresh(obj) return obj async def delete(self, obj: T) -> None: await self.session.delete(obj) await self.session.flush()