Skip to content

Commit 9132e66

Browse files
authored
Merge pull request #159 from igorbenav/blacklist-refresh-token
blacklist both tokens for logout
2 parents 4119051 + 1d4771d commit 9132e66

File tree

2 files changed

+60
-9
lines changed

2 files changed

+60
-9
lines changed

src/app/api/v1/logout.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,31 @@
1-
from fastapi import APIRouter, Depends, Response
1+
from fastapi import APIRouter, Depends, Response, Cookie
22
from jose import JWTError
33
from sqlalchemy.ext.asyncio import AsyncSession
4+
from typing import Optional
45

56
from ...core.db.database import async_get_db
67
from ...core.exceptions.http_exceptions import UnauthorizedException
7-
from ...core.security import blacklist_token, oauth2_scheme
8+
from ...core.security import blacklist_tokens, oauth2_scheme
89

910
router = APIRouter(tags=["login"])
1011

1112

1213
@router.post("/logout")
1314
async def logout(
14-
response: Response, access_token: str = Depends(oauth2_scheme), db: AsyncSession = Depends(async_get_db)
15+
response: Response,
16+
access_token: str = Depends(oauth2_scheme),
17+
refresh_token: Optional[str] = Cookie(None, alias="refresh_token"),
18+
db: AsyncSession = Depends(async_get_db)
1519
) -> dict[str, str]:
1620
try:
17-
await blacklist_token(token=access_token, db=db)
21+
if not refresh_token:
22+
raise UnauthorizedException("Refresh token not found")
23+
24+
await blacklist_tokens(
25+
access_token=access_token,
26+
refresh_token=refresh_token,
27+
db=db
28+
)
1829
response.delete_cookie(key="refresh_token")
1930

2031
return {"message": "Logged out successfully"}

src/app/core/security.py

Lines changed: 45 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from enum import Enum
12
from datetime import UTC, datetime, timedelta
23
from typing import Any, Literal
34

@@ -11,6 +12,7 @@
1112
from .db.crud_token_blacklist import crud_token_blacklist
1213
from .schemas import TokenBlacklistCreate, TokenData
1314

15+
1416
SECRET_KEY = settings.SECRET_KEY
1517
ALGORITHM = settings.ALGORITHM
1618
ACCESS_TOKEN_EXPIRE_MINUTES = settings.ACCESS_TOKEN_EXPIRE_MINUTES
@@ -19,6 +21,10 @@
1921
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/login")
2022

2123

24+
class TokenType(str, Enum):
25+
ACCESS = "access"
26+
REFRESH = "refresh"
27+
2228
async def verify_password(plain_password: str, hashed_password: str) -> bool:
2329
correct_password: bool = bcrypt.checkpw(plain_password.encode(), hashed_password.encode())
2430
return correct_password
@@ -50,7 +56,7 @@ async def create_access_token(data: dict[str, Any], expires_delta: timedelta | N
5056
expire = datetime.now(UTC).replace(tzinfo=None) + expires_delta
5157
else:
5258
expire = datetime.now(UTC).replace(tzinfo=None) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
53-
to_encode.update({"exp": expire})
59+
to_encode.update({"exp": expire, "token_type": TokenType.ACCESS})
5460
encoded_jwt: str = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
5561
return encoded_jwt
5662

@@ -61,18 +67,20 @@ async def create_refresh_token(data: dict[str, Any], expires_delta: timedelta |
6167
expire = datetime.now(UTC).replace(tzinfo=None) + expires_delta
6268
else:
6369
expire = datetime.now(UTC).replace(tzinfo=None) + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS)
64-
to_encode.update({"exp": expire})
70+
to_encode.update({"exp": expire, "token_type": TokenType.REFRESH})
6571
encoded_jwt: str = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
6672
return encoded_jwt
6773

6874

69-
async def verify_token(token: str, db: AsyncSession) -> TokenData | None:
75+
async def verify_token(token: str, expected_token_type: TokenType, db: AsyncSession) -> TokenData | None:
7076
"""Verify a JWT token and return TokenData if valid.
7177
7278
Parameters
7379
----------
7480
token: str
7581
The JWT token to be verified.
82+
expected_token_type: TokenType
83+
The expected type of token (access or refresh)
7684
db: AsyncSession
7785
Database session for performing database operations.
7886
@@ -88,15 +96,47 @@ async def verify_token(token: str, db: AsyncSession) -> TokenData | None:
8896
try:
8997
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
9098
username_or_email: str = payload.get("sub")
91-
if username_or_email is None:
99+
token_type: str = payload.get("token_type")
100+
101+
if username_or_email is None or token_type != expected_token_type:
92102
return None
103+
93104
return TokenData(username_or_email=username_or_email)
94105

95106
except JWTError:
96107
return None
97108

98109

110+
async def blacklist_tokens(access_token: str, refresh_token: str, db: AsyncSession) -> None:
111+
"""Blacklist both access and refresh tokens.
112+
113+
Parameters
114+
----------
115+
access_token: str
116+
The access token to blacklist
117+
refresh_token: str
118+
The refresh token to blacklist
119+
db: AsyncSession
120+
Database session for performing database operations.
121+
"""
122+
for token in [access_token, refresh_token]:
123+
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
124+
expires_at = datetime.fromtimestamp(payload.get("exp"))
125+
await crud_token_blacklist.create(
126+
db,
127+
object=TokenBlacklistCreate(
128+
token=token,
129+
expires_at=expires_at
130+
)
131+
)
132+
99133
async def blacklist_token(token: str, db: AsyncSession) -> None:
100134
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
101135
expires_at = datetime.fromtimestamp(payload.get("exp"))
102-
await crud_token_blacklist.create(db, object=TokenBlacklistCreate(**{"token": token, "expires_at": expires_at}))
136+
await crud_token_blacklist.create(
137+
db,
138+
object=TokenBlacklistCreate(
139+
token=token,
140+
expires_at=expires_at
141+
)
142+
)

0 commit comments

Comments
 (0)