fix OAuth authentication
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
from datetime import timedelta
|
||||
|
||||
from fastapi import APIRouter, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm, OAuth2PasswordBearer, SecurityScopes
|
||||
from pydantic import BaseModel
|
||||
|
||||
from typing import Annotated
|
||||
from src.core.config import settings
|
||||
from src.core.log_conf import logger
|
||||
from src.core.security import authenticate_user, create_access_token
|
||||
@@ -39,3 +40,17 @@ def login(request: LoginRequest) -> Token:
|
||||
expires_delta=access_token_expires,
|
||||
)
|
||||
return Token(access_token=access_token, token_type="bearer")
|
||||
|
||||
@login_router.post("/token", tags=["login"], summary="Login for access token")
|
||||
async def login_for_access_token(
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
) -> Token:
|
||||
user = authenticate_user(form_data.username, form_data.password)
|
||||
if not user:
|
||||
raise HTTPException(status_code=400, detail="Incorrect username or password")
|
||||
access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
access_token = create_access_token(
|
||||
data={"sub": user.user_name, "scope": " ".join(form_data.scopes)},
|
||||
expires_delta=access_token_expires,
|
||||
)
|
||||
return Token(access_token=access_token, token_type="bearer")
|
||||
|
||||
@@ -13,43 +13,43 @@ from pydantic import ValidationError
|
||||
from src.core.config import settings
|
||||
from src.core.log_conf import logger
|
||||
from src.db.models.admin import Profile
|
||||
from src.db.repository.admin import get_profile, is_database_empty
|
||||
from src.db.repository.admin import get_profile_by_username, is_database_empty
|
||||
from src.db.session import SessionLocal
|
||||
from src.schema.admin import ProfileModel, TokenData
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(
|
||||
tokenUrl="/api/login/token",
|
||||
scopes={"me": "read", "admin": "read"},
|
||||
tokenUrl="/token",
|
||||
scopes={"me": "read", "admin": "read", "ROLE_ADMIN": "admin", "ROLE_MEDIA": "media", "ROLE_USER": "user"},
|
||||
)
|
||||
|
||||
|
||||
class OAuth2PasswordBearerWithCookie(OAuth2):
|
||||
def __init__(
|
||||
self,
|
||||
tokenUrl: str,
|
||||
scheme_name: Optional[str] = None,
|
||||
scopes: Optional[Dict[str, str]] = None,
|
||||
auto_error: bool = True,
|
||||
):
|
||||
if not scopes:
|
||||
scopes = {}
|
||||
flows = OAuthFlowsModel(password={"tokenUrl": tokenUrl, "scopes": scopes}) # type: ignore
|
||||
super().__init__(flows=flows, scheme_name=scheme_name, auto_error=auto_error)
|
||||
# class OAuth2PasswordBearerWithCookie(OAuth2):
|
||||
# def __init__(
|
||||
# self,
|
||||
# tokenUrl: str,
|
||||
# scheme_name: Optional[str] = None,
|
||||
# scopes: Optional[Dict[str, str]] = None,
|
||||
# auto_error: bool = True,
|
||||
# ):
|
||||
# if not scopes:
|
||||
# scopes = {}
|
||||
# flows = OAuthFlowsModel(password={"tokenUrl": tokenUrl, "scopes": scopes}) # type: ignore
|
||||
# super().__init__(flows=flows, scheme_name=scheme_name, auto_error=auto_error)
|
||||
|
||||
async def __call__(self, request: Request) -> Optional[str]:
|
||||
authorization: str = request.cookies.get("access_token") # changed to accept access token from httpOnly Cookie
|
||||
# async def __call__(self, request: Request) -> Optional[str]:
|
||||
# authorization: str = request.cookies.get("access_token") # changed to accept access token from httpOnly Cookie
|
||||
|
||||
scheme, param = get_authorization_scheme_param(authorization)
|
||||
if not authorization or scheme.lower() != "bearer":
|
||||
if self.auto_error:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
else:
|
||||
return None
|
||||
return param
|
||||
# scheme, param = get_authorization_scheme_param(authorization)
|
||||
# if not authorization or scheme.lower() != "bearer":
|
||||
# if self.auto_error:
|
||||
# raise HTTPException(
|
||||
# status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
# detail="Not authenticated",
|
||||
# headers={"WWW-Authenticate": "Bearer"},
|
||||
# )
|
||||
# else:
|
||||
# return None
|
||||
# return param
|
||||
|
||||
|
||||
def authenticate_user(username: str, password: str) -> Optional[Profile]:
|
||||
@@ -110,10 +110,12 @@ async def get_current_user(
|
||||
token_scopes: List[str] = scope.split(" ")
|
||||
token_data = TokenData(scopes=token_scopes, username=username)
|
||||
except (JWTError, ValidationError):
|
||||
logger.info("Exception raised", exc_info=True)
|
||||
raise credentials_exception
|
||||
with SessionLocal() as db:
|
||||
user = get_profile(username=token_data.username, db=db) # type: ignore
|
||||
user = get_profile_by_username(username=token_data.username, db=db)
|
||||
if user is None:
|
||||
logger.info("user not found")
|
||||
raise credentials_exception
|
||||
for scope in security_scopes.scopes:
|
||||
if scope not in token_scopes:
|
||||
@@ -128,7 +130,7 @@ async def get_current_user(
|
||||
async def get_current_active_user(
|
||||
current_user: Annotated[Profile, Security(get_current_user, scopes=["me"])],
|
||||
) -> ProfileModel:
|
||||
if not current_user.enabled: # type: ignore
|
||||
if not current_user.enabled:
|
||||
raise HTTPException(status_code=400, detail="Inactive user")
|
||||
user_model = ProfileModel(
|
||||
username=current_user.user_name,
|
||||
|
||||
@@ -5,8 +5,12 @@ from sqlalchemy.orm import Session
|
||||
from src.db.models.admin import Profile
|
||||
|
||||
|
||||
def get_profile(username: AnyStr, db: Session) -> Optional[Profile]:
|
||||
profile = db.query(Profile).filter(Profile.email == username).first()
|
||||
def get_profile_by_username(username: AnyStr, db: Session) -> Optional[Profile]:
|
||||
profile = db.query(Profile).filter(Profile.user_name == username).first()
|
||||
return profile
|
||||
|
||||
def get_profile_by_email(email: AnyStr, db: Session) -> Optional[Profile]:
|
||||
profile = db.query(Profile).filter(Profile.email == email).first()
|
||||
return profile
|
||||
|
||||
def is_database_empty(db: Session) -> bool:
|
||||
|
||||
Reference in New Issue
Block a user