Skip to content

Commit a4c0e55

Browse files
authored
Merge pull request #56 from diogob/jwt-validate-exp
Jwt validate exp
2 parents 3fa52d4 + 4744615 commit a4c0e55

File tree

7 files changed

+86
-51
lines changed

7 files changed

+86
-51
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# CHANGELOG
22

3+
## Next release
4+
5+
- Send close connection once the JWT token expires (if channel is open with a token using the `exp` claim).
6+
37
## 0.6.1.0
48

59
- Add capability to unset `PGWS_ROOT_PATH` to disable static file serving.

app/Main.hs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@ import qualified Hasql.Decoders as HD
1919
import qualified Hasql.Encoders as HE
2020
import qualified Hasql.Pool as P
2121
import Network.Wai.Application.Static
22+
import Data.Time.Clock (UTCTime, getCurrentTime)
23+
import Control.AutoUpdate ( defaultUpdateSettings
24+
, mkAutoUpdate
25+
, updateAction
26+
)
2227

2328
import Network.Wai (Application, responseLBS)
2429
import Network.HTTP.Types (status200)
@@ -61,11 +66,14 @@ main = do
6166

6267
pool <- P.acquire (configPool conf, 10, pgSettings)
6368
multi <- newHasqlBroadcaster listenChannel pgSettings
69+
getTime <- mkGetTime
6470

6571
runSettings appSettings $
66-
postgresWsMiddleware listenChannel (configJwtSecret conf) pool multi $
72+
postgresWsMiddleware getTime listenChannel (configJwtSecret conf) pool multi $
6773
logStdout $ maybe dummyApp staticApp' (configPath conf)
6874
where
75+
mkGetTime :: IO (IO UTCTime)
76+
mkGetTime = mkAutoUpdate defaultUpdateSettings {updateAction = getCurrentTime}
6977
staticApp' :: Text -> Application
7078
staticApp' = staticApp . defaultFileServerSettings . toS
7179
dummyApp :: Application

client-example/screen.css

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ h2 {
2121
}
2222

2323
div#main {
24-
width: 600px;
24+
width: 50%;
2525
margin: 0px auto 0px auto;
2626
padding: 0px;
2727
background-color: #fff;

postgres-websockets.cabal

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ library
4444
, stringsearch >= 0.3.6.6 && < 0.4
4545
, time >= 1.8.0.2 && < 1.9
4646
, contravariant >= 1.5.2 && < 1.6
47+
, alarmclock >= 0.7.0.2 && < 0.8
4748
default-language: Haskell2010
4849
default-extensions: OverloadedStrings, NoImplicitPrelude, LambdaCase
4950

@@ -67,8 +68,9 @@ executable postgres-websockets
6768
, wai >= 3.2 && < 4
6869
, wai-extra >= 3.0.29 && < 3.1
6970
, wai-app-static >= 3.1.7.1 && < 3.2
70-
, http-types
71+
, http-types >= 0.9
7172
, envparse >= 0.4.1
73+
, auto-update >= 0.1.6 && < 0.2
7274
default-language: Haskell2010
7375
default-extensions: OverloadedStrings, NoImplicitPrelude, QuasiQuotes
7476

@@ -82,18 +84,18 @@ test-suite postgres-websockets-test
8284
build-depends: base
8385
, protolude >= 0.2.3
8486
, postgres-websockets
85-
, containers
86-
, hspec
87-
, hspec-wai
88-
, hspec-wai-json
89-
, aeson
90-
, hasql
91-
, hasql-pool
87+
, hspec >= 2.7.1 && < 2.8
88+
, hspec-wai >= 0.9.2 && < 0.10
89+
, hspec-wai-json >= 0.9.2 && < 0.10
90+
, aeson >= 1.4.6.0 && < 1.5
91+
, hasql >= 0.19
92+
, hasql-pool >= 0.4
9293
, hasql-notifications >= 0.1.0.0 && < 0.2
93-
, http-types
94+
, http-types >= 0.9
95+
, time >= 1.8.0.2 && < 1.9
9496
, unordered-containers >= 0.2
95-
, wai-extra
96-
, stm
97+
, wai-extra >= 3.0.29 && < 3.1
98+
, stm >= 2.5.0.0 && < 2.6
9799
ghc-options: -Wall -threaded -rtsopts -with-rtsopts=-N
98100
default-language: Haskell2010
99101
default-extensions: OverloadedStrings, NoImplicitPrelude

src/PostgresWebsockets.hs

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@ import qualified Data.ByteString.Char8 as BS
2222
import qualified Data.ByteString.Lazy as BL
2323
import qualified Data.HashMap.Strict as M
2424
import qualified Data.Text.Encoding.Error as T
25-
import Data.Time.Clock.POSIX (getPOSIXTime)
25+
import Data.Time.Clock (UTCTime)
26+
import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds, posixSecondsToUTCTime)
27+
import Control.Concurrent.AlarmClock (newAlarmClock, setAlarm)
2628
import PostgresWebsockets.Broadcast (Multiplexer, onMessage)
2729
import qualified PostgresWebsockets.Broadcast as B
2830
import PostgresWebsockets.Claims
@@ -38,19 +40,21 @@ data Message = Message
3840
instance A.ToJSON Message
3941

4042
-- | Given a secret, a function to fetch the system time, a Hasql Pool and a Multiplexer this will give you a WAI middleware.
41-
postgresWsMiddleware :: Text -> ByteString -> H.Pool -> Multiplexer -> Wai.Application -> Wai.Application
43+
postgresWsMiddleware :: IO UTCTime -> Text -> ByteString -> H.Pool -> Multiplexer -> Wai.Application -> Wai.Application
4244
postgresWsMiddleware =
4345
WS.websocketsOr WS.defaultConnectionOptions `compose` wsApp
4446
where
45-
compose = (.) . (.) . (.) . (.)
47+
compose = (.) . (.) . (.) . (.) . (.)
4648

4749
-- private functions
50+
jwtExpirationStatusCode :: Word16
51+
jwtExpirationStatusCode = 3001
4852

4953
-- when the websocket is closed a ConnectionClosed Exception is triggered
5054
-- this kills all children and frees resources for us
51-
wsApp :: Text -> ByteString -> H.Pool -> Multiplexer -> WS.ServerApp
52-
wsApp dbChannel secret pool multi pendingConn =
53-
validateClaims requestChannel secret (toS jwtToken) >>= either rejectRequest forkSessions
55+
wsApp :: IO UTCTime -> Text -> ByteString -> H.Pool -> Multiplexer -> WS.ServerApp
56+
wsApp getTime dbChannel secret pool multi pendingConn =
57+
getTime >>= validateClaims requestChannel secret (toS jwtToken) >>= either rejectRequest forkSessions
5458
where
5559
hasRead m = m == ("r" :: ByteString) || m == ("rw" :: ByteString)
5660
hasWrite m = m == ("w" :: ByteString) || m == ("rw" :: ByteString)
@@ -68,27 +72,34 @@ wsApp dbChannel secret pool multi pendingConn =
6872
-- We should accept only after verifying JWT
6973
conn <- WS.acceptRequest pendingConn
7074
-- Fork a pinging thread to ensure browser connections stay alive
71-
WS.forkPingThread conn 30
75+
WS.withPingThread conn 30 (pure ()) $ do
76+
case M.lookup "exp" validClaims of
77+
Just (A.Number expClaim) -> do
78+
connectionExpirer <- newAlarmClock $ const (WS.sendCloseCode conn jwtExpirationStatusCode ("JWT expired" :: ByteString))
79+
setAlarm connectionExpirer (posixSecondsToUTCTime $ realToFrac expClaim)
80+
Just _ -> pure ()
81+
Nothing -> pure ()
7282

73-
when (hasRead mode) $
74-
onMessage multi ch $ WS.sendTextData conn . B.payload
83+
when (hasRead mode) $
84+
onMessage multi ch $ WS.sendTextData conn . B.payload
7585

76-
when (hasWrite mode) $
77-
let sendNotifications = void . H.notifyPool pool dbChannel . toS
78-
in notifySession validClaims (toS ch) conn sendNotifications
86+
when (hasWrite mode) $
87+
let sendNotifications = void . H.notifyPool pool dbChannel . toS
88+
in notifySession validClaims (toS ch) conn getTime sendNotifications
7989

80-
waitForever <- newEmptyMVar
81-
void $ takeMVar waitForever
90+
waitForever <- newEmptyMVar
91+
void $ takeMVar waitForever
8292

8393
-- Having both channel and claims as parameters seem redundant
8494
-- But it allows the function to ignore the claims structure and the source
8595
-- of the channel, so all claims decoding can be coded in the caller
8696
notifySession :: A.Object
8797
-> Text
8898
-> WS.Connection
99+
-> IO UTCTime
89100
-> (ByteString -> IO ())
90101
-> IO ()
91-
notifySession claimsToSend ch wsCon send =
102+
notifySession claimsToSend ch wsCon getTime send =
92103
withAsync (forever relayData) wait
93104
where
94105
relayData = jsonMsgWithTime >>= send
@@ -102,5 +113,5 @@ notifySession claimsToSend ch wsCon send =
102113
claimsWithChannel = M.insert "channel" (A.String ch) claimsToSend
103114
claimsWithTime :: IO (M.HashMap Text A.Value)
104115
claimsWithTime = do
105-
time <- getPOSIXTime
106-
return $ M.insert "message_delivered_at" (A.Number $ fromRational $ toRational time) claimsWithChannel
116+
time <- utcTimeToPOSIXSeconds <$> getTime
117+
return $ M.insert "message_delivered_at" (A.Number $ realToFrac time) claimsWithChannel

src/PostgresWebsockets/Claims.hs

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,23 +10,27 @@ module PostgresWebsockets.Claims
1010
import Control.Lens
1111
import qualified Crypto.JOSE.Types as JOSE.Types
1212
import Crypto.JWT
13-
import Data.Aeson (Value (..), decode, toJSON)
1413
import qualified Data.HashMap.Strict as M
1514
import Protolude
15+
import Data.Time.Clock (UTCTime)
16+
import Data.String (String, fromString)
17+
import qualified Data.Aeson as JSON
18+
import qualified Data.Aeson.Types as JSON
1619

1720

18-
type Claims = M.HashMap Text Value
21+
type Claims = M.HashMap Text JSON.Value
1922
type ConnectionInfo = (ByteString, ByteString, Claims)
2023

2124
{-| Given a secret, a token and a timestamp it validates the claims and returns
2225
either an error message or a triple containing channel, mode and claims hashmap.
2326
-}
24-
validateClaims :: Maybe ByteString -> ByteString -> LByteString -> IO (Either Text ConnectionInfo)
25-
validateClaims requestChannel secret jwtToken =
27+
validateClaims :: Maybe ByteString -> ByteString -> LByteString -> UTCTime -> IO (Either Text ConnectionInfo)
28+
validateClaims requestChannel secret jwtToken time =
2629
runExceptT $ do
27-
cl <- liftIO $ jwtClaims (parseJWK secret) jwtToken
30+
cl <- liftIO $ jwtClaims time (parseJWK secret) jwtToken
2831
cl' <- case cl of
2932
JWTClaims c -> pure c
33+
JWTInvalid JWTExpired -> throwError "Token expired"
3034
_ -> throwError "Error"
3135
channel <- claimAsJSON requestChannel "channel" cl'
3236
mode <- claimAsJSON Nothing "mode" cl'
@@ -35,7 +39,7 @@ validateClaims requestChannel secret jwtToken =
3539
where
3640
claimAsJSON :: Maybe ByteString -> Text -> Claims -> ExceptT Text IO ByteString
3741
claimAsJSON defaultVal name cl = case M.lookup name cl of
38-
Just (String s) -> pure $ encodeUtf8 s
42+
Just (JSON.String s) -> pure $ encodeUtf8 s
3943
Just _ -> throwError "claim is not string value"
4044
Nothing -> nonExistingClaim defaultVal name
4145

@@ -53,20 +57,20 @@ validateClaims requestChannel secret jwtToken =
5357
-}
5458
data JWTAttempt = JWTInvalid JWTError
5559
| JWTMissingSecret
56-
| JWTClaims (M.HashMap Text Value)
60+
| JWTClaims (M.HashMap Text JSON.Value)
5761
deriving Eq
5862

5963
{-|
6064
Receives the JWT secret (from config) and a JWT and returns a map
6165
of JWT claims.
6266
-}
63-
jwtClaims :: JWK -> LByteString -> IO JWTAttempt
64-
jwtClaims _ "" = return $ JWTClaims M.empty
65-
jwtClaims secret payload = do
66-
let validation = defaultJWTValidationSettings (const True)
67+
jwtClaims :: UTCTime -> JWK -> LByteString -> IO JWTAttempt
68+
jwtClaims _ _ "" = return $ JWTClaims M.empty
69+
jwtClaims time jwk payload = do
70+
let config = defaultJWTValidationSettings (const True)
6771
eJwt <- runExceptT $ do
6872
jwt <- decodeCompact payload
69-
verifyClaims validation secret jwt
73+
verifyClaimsAt config jwk time jwt
7074
return $ case eJwt of
7175
Left e -> JWTInvalid e
7276
Right jwt -> JWTClaims . claims2map $ jwt
@@ -75,10 +79,10 @@ jwtClaims secret payload = do
7579
Internal helper used to turn JWT ClaimSet into something
7680
easier to work with
7781
-}
78-
claims2map :: ClaimsSet -> M.HashMap Text Value
79-
claims2map = val2map . toJSON
82+
claims2map :: ClaimsSet -> M.HashMap Text JSON.Value
83+
claims2map = val2map . JSON.toJSON
8084
where
81-
val2map (Object o) = o
85+
val2map (JSON.Object o) = o
8286
val2map _ = M.empty
8387

8488
{-|
@@ -96,4 +100,4 @@ hs256jwk key =
96100

97101
parseJWK :: ByteString -> JWK
98102
parseJWK str =
99-
fromMaybe (hs256jwk str) (decode (toS str) :: Maybe JWK)
103+
fromMaybe (hs256jwk str) (JSON.decode (toS str) :: Maybe JWK)

test/ClaimsSpec.hs

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,19 @@ import Protolude
55
import qualified Data.HashMap.Strict as M
66
import Test.Hspec
77
import Data.Aeson (Value (..) )
8-
8+
import Data.Time.Clock
99
import PostgresWebsockets.Claims
1010

1111
spec :: Spec
1212
spec =
13-
describe "validate claims"
14-
$ it "should succeed using a matching token"
15-
$ validateClaims Nothing "reallyreallyreallyreallyverysafe"
16-
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoiciIsImNoYW5uZWwiOiJ0ZXN0In0.1d4s-at2kWj8OSabHZHTbNh1dENF7NWy_r0ED3Rwf58"
13+
describe "validate claims" $ do
14+
it "should invalidate an expired token" $ do
15+
time <- getCurrentTime
16+
validateClaims Nothing "reallyreallyreallyreallyverysafe"
17+
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoiciIsImNoYW5uZWwiOiJ0ZXN0IiwiZXhwIjoxfQ.4rDYiMZFR2WHB7Eq4HMdvDP_BQZVtHIfyJgy0NshbHY" time
18+
`shouldReturn` Left "Token expired"
19+
it "should succeed using a matching token" $ do
20+
time <- getCurrentTime
21+
validateClaims Nothing "reallyreallyreallyreallyverysafe"
22+
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoiciIsImNoYW5uZWwiOiJ0ZXN0In0.1d4s-at2kWj8OSabHZHTbNh1dENF7NWy_r0ED3Rwf58" time
1723
`shouldReturn` Right ("test", "r", M.fromList[("mode",String "r"),("channel",String "test")])

0 commit comments

Comments
 (0)