73 lines
2.0 KiB
Python
73 lines
2.0 KiB
Python
from __future__ import annotations

from collections.abc import AsyncGenerator

import pytest
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from sqlalchemy.engine import make_url
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlmodel import SQLModel

from app.api.deps import get_session
from app.core.config import settings
from app.core.constants import Role
from app.core.security import create_access_token, hash_password
from app.main import create_app
|
|
|
|
# Use a dedicated test database: same server and credentials as the app's
# configured DSN, with "_test" appended to the database name.
#
# Only the database component of the URL is rewritten. A plain
# ``str.replace`` would also rewrite the name wherever it happens to appear
# in the user, password or host. It would also silently return the
# unchanged DSN -- i.e. the real database, whose tables the db_session
# fixture drops -- if the name were not found in it at all.
TEST_MARIADB_DSN = (
    make_url(str(settings.MARIADB_DSN))
    .set(database=f"{settings.MARIADB_DATABASE}_test")
    # hide_password=False: the engine needs the real credentials, not "***".
    .render_as_string(hide_password=False)
)

test_engine = create_async_engine(TEST_MARIADB_DSN, echo=False)
# expire_on_commit=False keeps ORM attributes readable after a commit
# without triggering a new (async) lazy load.
TestSessionLocal = sessionmaker(
    bind=test_engine, class_=AsyncSession, expire_on_commit=False
)
|
|
|
|
|
|
@pytest_asyncio.fixture
async def db_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield an ``AsyncSession`` on a freshly created test schema.

    Before the test, every ``SQLModel`` table is dropped and re-created on
    ``test_engine``. After the test, the session is closed, the tables are
    dropped, and the engine's connection pool is disposed.

    Yields:
        An ``AsyncSession`` produced by ``TestSessionLocal``.
    """
    async with test_engine.begin() as conn:
        # create_all() skips tables that already exist, so drop first:
        # leftovers from an aborted earlier run would otherwise leak their
        # rows into this test.
        await conn.run_sync(SQLModel.metadata.drop_all)
        await conn.run_sync(SQLModel.metadata.create_all)

    try:
        async with TestSessionLocal() as session:
            yield session
    finally:
        try:
            async with test_engine.begin() as conn:
                await conn.run_sync(SQLModel.metadata.drop_all)
        finally:
            # Always release pooled connections, even if dropping the
            # schema failed.
            # NOTE(review): presumably done per test because pooled async
            # connections are tied to the event loop of the test that
            # opened them -- confirm against the pytest-asyncio loop scope.
            await test_engine.dispose()
|
|
|
|
|
|
@pytest_asyncio.fixture
async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]:
    """Yield an ``AsyncClient`` wired in-process to a freshly built app.

    The app's ``get_session`` dependency is replaced so that every request
    handled during the test uses the test's ``db_session``. Requests go
    through ``ASGITransport`` against the base URL ``http://test``.
    """

    async def _use_test_session() -> AsyncGenerator[AsyncSession, None]:
        # Hand out the fixture's session instead of opening a new one.
        yield db_session

    application = create_app()
    application.dependency_overrides[get_session] = _use_test_session

    async with AsyncClient(
        transport=ASGITransport(app=application),
        base_url="http://test",
    ) as http_client:
        yield http_client
|
|
|
|
|
|
@pytest.fixture
def admin_token() -> str:
    """Return an access token with subject ``1`` and ``Role.SUPERADMIN``."""
    token = create_access_token(subject=1, role=Role.SUPERADMIN)
    return token
|
|
|
|
|
|
@pytest.fixture
def user_token() -> str:
    """Return an access token with subject ``2`` and ``Role.USER``."""
    token = create_access_token(subject=2, role=Role.USER)
    return token
|
|
|
|
|
|
@pytest.fixture
def auth_headers(admin_token: str) -> dict[str, str]:
    """Return request headers carrying ``admin_token`` as a Bearer credential."""
    return dict(Authorization=f"Bearer {admin_token}")
|
|
|
|
|
|
@pytest.fixture
def user_headers(user_token: str) -> dict[str, str]:
    """Return request headers carrying ``user_token`` as a Bearer credential."""
    return dict(Authorization=f"Bearer {user_token}")
|