6310377d84
Remove endpoints api/login/token and api/login/profile --------- Co-authored-by: Thomas Peetz <thomas.peetz@cimt-ag.de> Reviewed-on: #89
76 lines
1.9 KiB
Python
76 lines
1.9 KiB
Python
|
|
import os
import sys
from typing import Any, Generator

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker

# Put the backend directory (the parent of this file's directory) on sys.path
# so the project-local ``src`` package imported below can be resolved.
# This has to happen BEFORE the ``src`` imports, otherwise it cannot help them.
_BACKEND_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _BACKEND_DIR not in sys.path:
    sys.path.append(_BACKEND_DIR)

from src.apis.base import api_router  # noqa: E402  (needs the sys.path entry above)
from src.db.models.base import Base  # noqa: E402
from src.db.session import get_db  # noqa: E402
|
|
|
|
|
|
|
|
def start_application():
    """Build and return a new FastAPI application with ``api_router`` mounted.

    A new ``FastAPI`` instance is created on every call.
    """
    application = FastAPI()
    application.include_router(api_router)
    return application
|
|
|
|
|
|
# Test database: a SQLite file named ``test_db.db``, resolved relative to the
# process's current working directory.
SQLALCHEMY_DATABASE_URL = "sqlite:///./test_db.db"

# ``check_same_thread`` is a SQLite-only connect argument; remove it if the URL
# above is ever switched to a different database backend.
engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    connect_args={"check_same_thread": False},
)

# Session factory for the tests: bound to the test engine, with autocommit and
# autoflush both disabled.
SessionTesting = sessionmaker(bind=engine, autoflush=False, autocommit=False)
|
|
|
|
|
|
@pytest.fixture(scope="module")
def app() -> Generator[FastAPI, Any, None]:
    """Provide a FastAPI app backed by freshly created test tables.

    This fixture is module-scoped. Every table in ``Base.metadata`` is created
    on the test engine once, before the first test in a module that uses it.
    The tables are dropped after that module's last test. Tests within the same
    module therefore share one schema. They are NOT given a fresh database per
    test case.
    """
    Base.metadata.create_all(engine)  # Create the tables on the test engine.
    _app = start_application()
    yield _app
    # Teardown: runs once the module's tests have finished.
    Base.metadata.drop_all(engine)
|
|
|
|
|
|
@pytest.fixture(scope="module")
def db_session(app: FastAPI) -> Generator[Session, Any, None]:
    """Yield a session bound to one connection that has an open transaction.

    ``app`` is requested but not used directly. Requesting it makes pytest set
    up the ``app`` fixture, which creates the tables, before this session is
    opened. On teardown, the session is closed, the outer transaction is rolled
    back, and the connection is released.

    NOTE(review): whether ``session.commit()`` calls made inside routes escape
    this rollback depends on how the installed SQLAlchemy version joins a
    Session to a connection that is already in a transaction. Confirm this.
    """
    conn = engine.connect()
    outer_txn = conn.begin()
    test_session = SessionTesting(bind=conn)

    yield test_session

    # Release resources in the reverse order they were acquired.
    test_session.close()
    outer_txn.rollback()
    conn.close()
|
|
|
|
|
|
@pytest.fixture(scope="module")
def client(
    app: FastAPI, db_session: Session
) -> Generator[TestClient, Any, None]:
    """Yield a TestClient whose requests use the shared ``db_session``.

    The ``get_db`` dependency injected into routes is overridden to hand out
    ``db_session``. The override is removed again on teardown, so the app no
    longer points at a session that the ``db_session`` fixture is about to
    close.
    """

    def _get_test_db():
        # Kept as a generator. Nothing is cleaned up here, because closing the
        # session is the ``db_session`` fixture's responsibility.
        yield db_session

    app.dependency_overrides[get_db] = _get_test_db
    try:
        with TestClient(app) as test_client:
            yield test_client
    finally:
        # Undo the override so it does not outlive this fixture.
        app.dependency_overrides.pop(get_db, None)
|