51 lines
1.8 KiB
Python
51 lines
1.8 KiB
Python
from __future__ import annotations

from datetime import datetime, timezone

from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.mariadb.auth import OAuthAccount, RefreshToken
from app.repositories.base import BaseRepository
|
|
|
|
|
|
class AuthRepository(BaseRepository[RefreshToken]):
    """Data access for refresh tokens and linked OAuth accounts.

    Generic operations for ``RefreshToken`` come from ``BaseRepository``;
    this class adds token lookup / revocation queries and ``OAuthAccount``
    lookup / creation, all on the same ``AsyncSession``.
    """

    def __init__(self, session: AsyncSession):
        super().__init__(RefreshToken, session)

    async def get_by_token(self, token: str) -> RefreshToken | None:
        """Return the usable refresh token whose value equals ``token``.

        A row is returned only if it is not revoked and its ``expires_at``
        is later than the current UTC time; otherwise ``None`` is returned.

        Raises:
            sqlalchemy.exc.MultipleResultsFound: if more than one row
                matches (``scalar_one_or_none`` semantics).
        """
        # Naive UTC "now" -- the same value the deprecated datetime.utcnow()
        # returned, so the comparison against expires_at is unchanged.
        # NOTE(review): assumes expires_at is stored as naive UTC -- confirm.
        now = datetime.now(timezone.utc).replace(tzinfo=None)
        stmt = select(RefreshToken).where(
            RefreshToken.token == token,
            RefreshToken.is_revoked == False,  # noqa: E712
            RefreshToken.expires_at > now,
        )
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def revoke_all_for_user(self, user_id: int) -> None:
        """Mark every not-yet-revoked refresh token of ``user_id`` as revoked.

        Uses one ``UPDATE`` statement instead of loading each row into the
        session and flushing a separate change per token.
        """
        stmt = (
            update(RefreshToken)
            .where(
                RefreshToken.user_id == user_id,
                RefreshToken.is_revoked == False,  # noqa: E712
            )
            .values(is_revoked=True)
        )
        # Executed through the ORM session, so SQLAlchemy's default
        # synchronize_session strategy updates RefreshToken objects already
        # loaded in this session.
        # NOTE(review): bulk UPDATEs do not fire per-object mapper events
        # (before_update/after_update); confirm none are registered on
        # RefreshToken.
        await self.session.execute(stmt)
        # Keep the original contract of leaving the session flushed.
        await self.session.flush()

    async def get_oauth_account(
        self, provider: str, provider_user_id: str
    ) -> OAuthAccount | None:
        """Return the OAuth account for ``(provider, provider_user_id)``.

        Returns ``None`` when no row matches.

        Raises:
            sqlalchemy.exc.MultipleResultsFound: if more than one row
                matches (``scalar_one_or_none`` semantics).
        """
        stmt = select(OAuthAccount).where(
            OAuthAccount.provider == provider,
            OAuthAccount.provider_user_id == provider_user_id,
        )
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def create_oauth_account(self, account: OAuthAccount) -> OAuthAccount:
        """Persist ``account`` and return the same, refreshed instance.

        The instance is added to the session and flushed so the INSERT is
        issued. It is then refreshed so its attributes are reloaded from the
        database (e.g. any server-generated columns).
        """
        self.session.add(account)
        await self.session.flush()
        await self.session.refresh(account)
        return account
|