1+ from enum import Enum
12from datetime import UTC , datetime , timedelta
23from typing import Any , Literal
34
1112from .db .crud_token_blacklist import crud_token_blacklist
1213from .schemas import TokenBlacklistCreate , TokenData
1314
15+
1416SECRET_KEY = settings .SECRET_KEY
1517ALGORITHM = settings .ALGORITHM
1618ACCESS_TOKEN_EXPIRE_MINUTES = settings .ACCESS_TOKEN_EXPIRE_MINUTES
1921oauth2_scheme = OAuth2PasswordBearer (tokenUrl = "/api/v1/login" )
2022
2123
24+ class TokenType (str , Enum ):
25+ ACCESS = "access"
26+ REFRESH = "refresh"
27+
2228async def verify_password (plain_password : str , hashed_password : str ) -> bool :
2329 correct_password : bool = bcrypt .checkpw (plain_password .encode (), hashed_password .encode ())
2430 return correct_password
@@ -50,7 +56,7 @@ async def create_access_token(data: dict[str, Any], expires_delta: timedelta | N
5056 expire = datetime .now (UTC ).replace (tzinfo = None ) + expires_delta
5157 else :
5258 expire = datetime .now (UTC ).replace (tzinfo = None ) + timedelta (minutes = ACCESS_TOKEN_EXPIRE_MINUTES )
53- to_encode .update ({"exp" : expire })
59+ to_encode .update ({"exp" : expire , "token_type" : TokenType . ACCESS })
5460 encoded_jwt : str = jwt .encode (to_encode , SECRET_KEY , algorithm = ALGORITHM )
5561 return encoded_jwt
5662
@@ -61,18 +67,20 @@ async def create_refresh_token(data: dict[str, Any], expires_delta: timedelta |
6167 expire = datetime .now (UTC ).replace (tzinfo = None ) + expires_delta
6268 else :
6369 expire = datetime .now (UTC ).replace (tzinfo = None ) + timedelta (days = REFRESH_TOKEN_EXPIRE_DAYS )
64- to_encode .update ({"exp" : expire })
70+ to_encode .update ({"exp" : expire , "token_type" : TokenType . REFRESH })
6571 encoded_jwt : str = jwt .encode (to_encode , SECRET_KEY , algorithm = ALGORITHM )
6672 return encoded_jwt
6773
6874
69- async def verify_token (token : str , db : AsyncSession ) -> TokenData | None :
75+ async def verify_token (token : str , expected_token_type : TokenType , db : AsyncSession ) -> TokenData | None :
7076 """Verify a JWT token and return TokenData if valid.
7177
7278 Parameters
7379 ----------
7480 token: str
7581 The JWT token to be verified.
82+ expected_token_type: TokenType
83+ The expected type of token (access or refresh)
7684 db: AsyncSession
7785 Database session for performing database operations.
7886
@@ -88,15 +96,47 @@ async def verify_token(token: str, db: AsyncSession) -> TokenData | None:
8896 try :
8997 payload = jwt .decode (token , SECRET_KEY , algorithms = [ALGORITHM ])
9098 username_or_email : str = payload .get ("sub" )
91- if username_or_email is None :
99+ token_type : str = payload .get ("token_type" )
100+
101+ if username_or_email is None or token_type != expected_token_type :
92102 return None
103+
93104 return TokenData (username_or_email = username_or_email )
94105
95106 except JWTError :
96107 return None
97108
98109
110+ async def blacklist_tokens (access_token : str , refresh_token : str , db : AsyncSession ) -> None :
111+ """Blacklist both access and refresh tokens.
112+
113+ Parameters
114+ ----------
115+ access_token: str
116+ The access token to blacklist
117+ refresh_token: str
118+ The refresh token to blacklist
119+ db: AsyncSession
120+ Database session for performing database operations.
121+ """
122+ for token in [access_token , refresh_token ]:
123+ payload = jwt .decode (token , SECRET_KEY , algorithms = [ALGORITHM ])
124+ expires_at = datetime .fromtimestamp (payload .get ("exp" ))
125+ await crud_token_blacklist .create (
126+ db ,
127+ object = TokenBlacklistCreate (
128+ token = token ,
129+ expires_at = expires_at
130+ )
131+ )
132+
99133async def blacklist_token (token : str , db : AsyncSession ) -> None :
100134 payload = jwt .decode (token , SECRET_KEY , algorithms = [ALGORITHM ])
101135 expires_at = datetime .fromtimestamp (payload .get ("exp" ))
102- await crud_token_blacklist .create (db , object = TokenBlacklistCreate (** {"token" : token , "expires_at" : expires_at }))
136+ await crud_token_blacklist .create (
137+ db ,
138+ object = TokenBlacklistCreate (
139+ token = token ,
140+ expires_at = expires_at
141+ )
142+ )
0 commit comments