from __future__ import annotations

from datetime import datetime, timezone

from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.mariadb.auth import OAuthAccount, RefreshToken
from app.repositories.base import BaseRepository


def _utc_now_naive() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Equivalent to the deprecated ``datetime.utcnow()``: the value is UTC but
    carries no tzinfo.
    """
    # NOTE(review): kept naive on purpose so comparisons against
    # RefreshToken.expires_at behave as before; presumably that column stores
    # naive UTC datetimes -- confirm against the model definition.
    return datetime.now(timezone.utc).replace(tzinfo=None)


class AuthRepository(BaseRepository[RefreshToken]):
    """Data access for refresh tokens and linked OAuth accounts."""

    def __init__(self, session: AsyncSession):
        """Bind the repository to ``session`` with ``RefreshToken`` as its model."""
        super().__init__(RefreshToken, session)

    async def get_by_token(self, token: str) -> RefreshToken | None:
        """Return the active refresh token whose value equals ``token``.

        A token is active when it is not revoked and its ``expires_at`` lies
        in the future (compared against the current UTC time).

        Returns:
            The matching ``RefreshToken``, or ``None`` if none matches.

        Raises:
            sqlalchemy.exc.MultipleResultsFound: if more than one row matches.
        """
        stmt = select(RefreshToken).where(
            RefreshToken.token == token,
            RefreshToken.is_revoked == False,  # noqa: E712 - SQL expression, not a Python comparison
            RefreshToken.expires_at > _utc_now_naive(),
        )
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def revoke_all_for_user(self, user_id: int) -> None:
        """Mark every non-revoked refresh token of ``user_id`` as revoked.

        Issued as one bulk UPDATE rather than loading each row and mutating
        it in Python. SQLAlchemy's default session synchronization keeps any
        matching ``RefreshToken`` instances already loaded in this session
        consistent with the new value.
        """
        stmt = (
            update(RefreshToken)
            .where(
                RefreshToken.user_id == user_id,
                RefreshToken.is_revoked == False,  # noqa: E712 - SQL expression
            )
            .values(is_revoked=True)
        )
        await self.session.execute(stmt)

    async def get_oauth_account(
        self, provider: str, provider_user_id: str
    ) -> OAuthAccount | None:
        """Return the OAuth account for ``provider`` / ``provider_user_id``.

        Returns:
            The matching ``OAuthAccount``, or ``None`` if none matches.

        Raises:
            sqlalchemy.exc.MultipleResultsFound: if more than one row matches.
        """
        stmt = select(OAuthAccount).where(
            OAuthAccount.provider == provider,
            OAuthAccount.provider_user_id == provider_user_id,
        )
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def create_oauth_account(self, account: OAuthAccount) -> OAuthAccount:
        """Persist ``account`` and return it with database-generated state loaded.

        The session is flushed (not committed) so the INSERT is emitted, then
        the instance is refreshed to pick up server-side values.
        """
        self.session.add(account)
        await self.session.flush()
        await self.session.refresh(account)
        return account