@@ -815,13 +815,20 @@ class JWTTestCase(unittest.HomeserverTestCase):
815815
816816 jwt_secret = "secret"
817817 jwt_algorithm = "HS256"
818+ base_config = {
819+ "enabled" : True ,
820+ "secret" : jwt_secret ,
821+ "algorithm" : jwt_algorithm ,
822+ }
818823
819- def make_homeserver (self , reactor , clock ):
820- self .hs = self .setup_test_homeserver ()
821- self .hs .config .jwt .jwt_enabled = True
822- self .hs .config .jwt .jwt_secret = self .jwt_secret
823- self .hs .config .jwt .jwt_algorithm = self .jwt_algorithm
824- return self .hs
824+ def default_config (self ):
825+ config = super ().default_config ()
826+
827+ # If jwt_config has been defined (eg via @override_config), don't replace it.
828+ if config .get ("jwt_config" ) is None :
829+ config ["jwt_config" ] = self .base_config
830+
831+ return config
825832
826833 def jwt_encode (self , payload : Dict [str , Any ], secret : str = jwt_secret ) -> str :
827834 # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
@@ -879,16 +886,7 @@ def test_login_no_sub(self):
879886 self .assertEqual (channel .json_body ["errcode" ], "M_FORBIDDEN" )
880887 self .assertEqual (channel .json_body ["error" ], "Invalid JWT" )
881888
882- @override_config (
883- {
884- "jwt_config" : {
885- "jwt_enabled" : True ,
886- "secret" : jwt_secret ,
887- "algorithm" : jwt_algorithm ,
888- "issuer" : "test-issuer" ,
889- }
890- }
891- )
889+ @override_config ({"jwt_config" : {** base_config , "issuer" : "test-issuer" }})
892890 def test_login_iss (self ):
893891 """Test validating the issuer claim."""
894892 # A valid issuer.
@@ -919,16 +917,7 @@ def test_login_iss_no_config(self):
919917 self .assertEqual (channel .result ["code" ], b"200" , channel .result )
920918 self .assertEqual (channel .json_body ["user_id" ], "@kermit:test" )
921919
922- @override_config (
923- {
924- "jwt_config" : {
925- "jwt_enabled" : True ,
926- "secret" : jwt_secret ,
927- "algorithm" : jwt_algorithm ,
928- "audiences" : ["test-audience" ],
929- }
930- }
931- )
920+ @override_config ({"jwt_config" : {** base_config , "audiences" : ["test-audience" ]}})
932921 def test_login_aud (self ):
933922 """Test validating the audience claim."""
934923 # A valid audience.
@@ -962,6 +951,19 @@ def test_login_aud_no_config(self):
962951 channel .json_body ["error" ], "JWT validation failed: Invalid audience"
963952 )
964953
954+ def test_login_default_sub (self ):
955+ """Test reading user ID from the default subject claim."""
956+ channel = self .jwt_login ({"sub" : "kermit" })
957+ self .assertEqual (channel .result ["code" ], b"200" , channel .result )
958+ self .assertEqual (channel .json_body ["user_id" ], "@kermit:test" )
959+
960+ @override_config ({"jwt_config" : {** base_config , "subject_claim" : "username" }})
961+ def test_login_custom_sub (self ):
962+ """Test reading user ID from a custom subject claim."""
963+ channel = self .jwt_login ({"username" : "frog" })
964+ self .assertEqual (channel .result ["code" ], b"200" , channel .result )
965+ self .assertEqual (channel .json_body ["user_id" ], "@frog:test" )
966+
965967 def test_login_no_token (self ):
966968 params = {"type" : "org.matrix.login.jwt" }
967969 channel = self .make_request (b"POST" , LOGIN_URL , params )
@@ -1024,12 +1026,14 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
10241026 ]
10251027 )
10261028
1027- def make_homeserver (self , reactor , clock ):
1028- self .hs = self .setup_test_homeserver ()
1029- self .hs .config .jwt .jwt_enabled = True
1030- self .hs .config .jwt .jwt_secret = self .jwt_pubkey
1031- self .hs .config .jwt .jwt_algorithm = "RS256"
1032- return self .hs
1029+ def default_config (self ):
1030+ config = super ().default_config ()
1031+ config ["jwt_config" ] = {
1032+ "enabled" : True ,
1033+ "secret" : self .jwt_pubkey ,
1034+ "algorithm" : "RS256" ,
1035+ }
1036+ return config
10331037
10341038 def jwt_encode (self , payload : Dict [str , Any ], secret : str = jwt_privatekey ) -> str :
10351039 # PyJWT 2.0.0 changed the return type of jwt.encode from bytes to str.
0 commit comments