diff --git a/Setup.hs b/Setup.hs index 9a994af..e8ef27d 100644 --- a/Setup.hs +++ b/Setup.hs @@ -1,2 +1,3 @@ import Distribution.Simple + main = defaultMain diff --git a/app/Main.hs b/app/Main.hs index c831e20..d98eb29 100644 --- a/app/Main.hs +++ b/app/Main.hs @@ -1,19 +1,19 @@ module Main where -import Protolude +import APrelude import PostgresWebsockets - import System.IO (BufferMode (..), hSetBuffering) main :: IO () main = do hSetBuffering stdout LineBuffering - hSetBuffering stdin LineBuffering + hSetBuffering stdin LineBuffering hSetBuffering stderr NoBuffering - putStrLn $ ("postgres-websockets " :: Text) - <> prettyVersion - <> " / Connects websockets to PostgreSQL asynchronous notifications." + putStrLn $ + "postgres-websockets " + <> unpack prettyVersion + <> " / Connects websockets to PostgreSQL asynchronous notifications." conf <- loadConfig - void $ serve conf \ No newline at end of file + void $ serve conf diff --git a/postgres-websockets.cabal b/postgres-websockets.cabal index 95868fa..536d3a9 100644 --- a/postgres-websockets.cabal +++ b/postgres-websockets.cabal @@ -18,7 +18,7 @@ common warnings common language default-language: Haskell2010 - default-extensions: OverloadedStrings, NoImplicitPrelude, LambdaCase, RecordWildCards, QuasiQuotes + default-extensions: OverloadedStrings, LambdaCase, RecordWildCards, QuasiQuotes library import: warnings @@ -29,6 +29,7 @@ library , PostgresWebsockets.HasqlBroadcast , PostgresWebsockets.Claims , PostgresWebsockets.Config + , APrelude autogen-modules: Paths_postgres_websockets other-modules: Paths_postgres_websockets @@ -49,8 +50,9 @@ library , http-types >= 0.12.3 && < 0.13 , jose >= 0.11 && < 0.12 , lens >= 5.2.3 && < 5.4 + , mtl >=2.3.1 && <2.4 + , async >=2.2.5 && <2.3 , postgresql-libpq >= 0.10.0 && < 0.12 - , protolude >= 0.2.3 && < 0.4 , retry >= 0.8.1.0 && < 0.10 , stm >= 2.5.0.0 && < 2.6 , stm-containers >= 1.1.0.2 && < 1.3 @@ -76,7 +78,6 @@ executable postgres-websockets ghc-options: -threaded -rtsopts -with-rtsopts=-N build-depends: base >= 4.7 && < 5 , postgres-websockets - , protolude >= 0.2.3 && < 0.4 default-language: Haskell2010 test-suite postgres-websockets-test @@ -90,7 +91,6 @@ test-suite postgres-websockets-test , HasqlBroadcastSpec , ServerSpec build-depends: base - , protolude >= 0.2.3 && < 0.4 , postgres-websockets , hspec >= 2.7.1 && < 2.12 , aeson >= 2.0 && < 2.3 diff --git a/src/APrelude.hs b/src/APrelude.hs new file mode 100644 index 0000000..7c29c83 --- /dev/null +++ b/src/APrelude.hs @@ -0,0 +1,104 @@ +module APrelude + ( Text, + ByteString, + LByteString, + Generic, + fromMaybe, + putErrLn, + fromRight, + isJust, + decodeUtf8, + encodeUtf8, + MVar, + readMVar, + swapMVar, + newMVar, + STM, + atomically, + ThreadId, + forkFinally, + forkIO, + killThread, + threadDelay, + (>=>), + when, + forever, + void, + panic, + SomeException, + throwError, + liftIO, + runExceptT, + unpack, + pack, + showText, + showBS, + LBS.fromStrict, + stdin, + stdout, + stderr, + hPutStrLn, + Word16, + forM, + forM_, + takeMVar, + newEmptyMVar, + wait, + headDef, + tailSafe, + withAsync, + putMVar, + die, + myThreadId, + replicateM, + bracket, + ) +where + +import Control.Concurrent (ThreadId, forkFinally, forkIO, killThread, myThreadId, threadDelay) +import Control.Concurrent.Async (wait, withAsync) +import Control.Concurrent.MVar (MVar, newEmptyMVar, newMVar, putMVar, readMVar, swapMVar, takeMVar) +import Control.Concurrent.STM (STM, atomically) +import Control.Exception (Exception, SomeException, bracket, throw) +import Control.Monad (forM, forM_, forever, replicateM, void, when, (>=>)) +import Control.Monad.Error.Class (throwError) +import Control.Monad.Except (runExceptT) +import Control.Monad.IO.Class (liftIO) +import Data.ByteString (ByteString) +import qualified Data.ByteString.Char8 as BS +import qualified Data.ByteString.Lazy as LBS +import Data.Either (fromRight) +import Data.Maybe (fromMaybe, isJust, listToMaybe) +import Data.Text (Text, pack, unpack) +import qualified Data.Text as T +import Data.Text.Encoding +import Data.Word (Word16) +import GHC.Generics (Generic) +import System.Exit (die) +import System.IO (hPutStrLn, stderr, stdin, stdout) + +showBS :: (Show a) => a -> BS.ByteString +showBS = BS.pack . show + +showText :: (Show a) => a -> Text +showText = T.pack . show + +type LByteString = LBS.ByteString + +-- | Uncatchable exceptions thrown and never caught. +newtype FatalError = FatalError {fatalErrorMessage :: Text} + deriving (Show) + +instance Exception FatalError + +panic :: Text -> a +panic a = throw (FatalError a) + +putErrLn :: Text -> IO () +putErrLn = hPutStrLn stderr . unpack + +headDef :: a -> [a] -> a +headDef def = fromMaybe def . listToMaybe + +tailSafe :: [a] -> [a] +tailSafe = drop 1 diff --git a/src/PostgresWebsockets.hs b/src/PostgresWebsockets.hs index f42e887..26a6306 100644 --- a/src/PostgresWebsockets.hs +++ b/src/PostgresWebsockets.hs @@ -1,16 +1,16 @@ -{-| -Module : PostgresWebsockets -Description : PostgresWebsockets main library interface. - -These are all function necessary to configure and start the server. --} +-- | +-- Module : PostgresWebsockets +-- Description : PostgresWebsockets main library interface. +-- +-- These are all function necessary to configure and start the server. module PostgresWebsockets - ( prettyVersion - , loadConfig - , serve - , postgresWsMiddleware - ) where + ( prettyVersion, + loadConfig, + serve, + postgresWsMiddleware, + ) +where -import PostgresWebsockets.Middleware ( postgresWsMiddleware ) -import PostgresWebsockets.Server ( serve ) -import PostgresWebsockets.Config ( prettyVersion, loadConfig ) +import PostgresWebsockets.Config (loadConfig, prettyVersion) +import PostgresWebsockets.Middleware (postgresWsMiddleware) +import PostgresWebsockets.Server (serve) diff --git a/src/PostgresWebsockets/Broadcast.hs b/src/PostgresWebsockets/Broadcast.hs index 8fac2db..99262a8 100644 --- a/src/PostgresWebsockets/Broadcast.hs +++ b/src/PostgresWebsockets/Broadcast.hs @@ -26,11 +26,10 @@ module PostgresWebsockets.Broadcast ) where +import APrelude import Control.Concurrent.STM.TChan import Control.Concurrent.STM.TQueue import qualified Data.Aeson as A -import Protolude hiding (toS) -import Protolude.Conv (toS) import qualified StmContainers.Map as M data Message = Message @@ -63,7 +62,7 @@ instance A.ToJSON MultiplexerSnapshot -- | Given a multiplexer derive a type that can be printed for debugging or logging purposes takeSnapshot :: Multiplexer -> IO MultiplexerSnapshot takeSnapshot multi = - MultiplexerSnapshot <$> size <*> e <*> thread + MultiplexerSnapshot <$> size <*> e <*> (pack <$> thread) where size = atomically $ M.size $ channels multi thread = show <$> readMVar (producerThreadId multi) @@ -113,7 +112,7 @@ superviseMultiplexer multi msInterval shouldRestart = do new <- reopenProducer multi void $ swapMVar (producerThreadId multi) new snapAfter <- takeSnapshot multi - putStrLn $ + print $ "Restarting producer. Multiplexer updated: " <> A.encode snapBefore <> " -> " @@ -142,7 +141,7 @@ onMessage multi chan action = do where disposeListener _ = atomically $ do mC <- M.lookup chan (channels multi) - let c = fromMaybe (panic $ "trying to remove listener from non existing channel: " <> toS chan) mC + let c = fromMaybe (panic $ "trying to remove listener from non existing channel: " <> chan) mC M.delete chan (channels multi) when (listeners c - 1 > 0) $ M.insert Channel {broadcast = broadcast c, listeners = listeners c - 1} chan (channels multi) diff --git a/src/PostgresWebsockets/Claims.hs b/src/PostgresWebsockets/Claims.hs index 9310c07..cd9fee5 100644 --- a/src/PostgresWebsockets/Claims.hs +++ b/src/PostgresWebsockets/Claims.hs @@ -1,54 +1,56 @@ -{-| -Module : PostgresWebsockets.Claims -Description : Parse and validate JWT to open postgres-websockets channels. - -This module provides the JWT claims validation. Since websockets and -listening connections in the database tend to be resource intensive -(not to mention stateful) we need claims authorizing a specific channel and -mode of operation. --} +-- | +-- Module : PostgresWebsockets.Claims +-- Description : Parse and validate JWT to open postgres-websockets channels. +-- +-- This module provides the JWT claims validation. Since websockets and +-- listening connections in the database tend to be resource intensive +-- (not to mention stateful) we need claims authorizing a specific channel and +-- mode of operation. module PostgresWebsockets.Claims - ( ConnectionInfo,validateClaims - ) where + ( ConnectionInfo, + validateClaims, + ) +where -import Protolude hiding (toS) -import Protolude.Conv +import APrelude import Control.Lens -import Crypto.JWT -import Data.List -import Data.Time.Clock (UTCTime) import qualified Crypto.JOSE.Types as JOSE.Types +import Crypto.JWT import qualified Data.Aeson as JSON -import qualified Data.Aeson.KeyMap as JSON import qualified Data.Aeson.Key as Key +import qualified Data.Aeson.KeyMap as JSON +import Data.List +import Data.Time.Clock (UTCTime) type Claims = JSON.KeyMap JSON.Value + type ConnectionInfo = ([Text], Text, Claims) -{-| Given a secret, a token and a timestamp it validates the claims and returns - either an error message or a triple containing channel, mode and claims KeyMap. --} -validateClaims - :: Maybe Text - -> ByteString - -> LByteString - -> UTCTime - -> IO (Either Text ConnectionInfo) +-- | Given a secret, a token and a timestamp it validates the claims and returns +-- either an error message or a triple containing channel, mode and claims KeyMap. +validateClaims :: + Maybe Text -> + ByteString -> + LByteString -> + UTCTime -> + IO (Either Text ConnectionInfo) validateClaims requestChannel secret jwtToken time = runExceptT $ do - cl <- liftIO $ jwtClaims time (parseJWK secret) jwtToken + cl <- liftIO $ jwtClaims time (parseJWK secret) jwtToken cl' <- case cl of - JWTClaims c -> pure c + JWTClaims c -> pure c JWTInvalid JWTExpired -> throwError "Token expired" - JWTInvalid err -> throwError $ "Error: " <> show err - channels <- let chs = claimAsJSONList "channels" cl' in pure $ case claimAsJSON "channel" cl' of - Just c -> case chs of - Just cs -> nub (c : cs) - Nothing -> [c] - Nothing -> fromMaybe [] chs + JWTInvalid err -> throwError $ "Error: " <> showText err + channels <- + let chs = claimAsJSONList "channels" cl' + in pure $ case claimAsJSON "channel" cl' of + Just c -> case chs of + Just cs -> nub (c : cs) + Nothing -> [c] + Nothing -> fromMaybe [] chs mode <- let md = claimAsJSON "mode" cl' - in case md of - Just m -> pure m + in case md of + Just m -> pure m Nothing -> throwError "Missing mode" requestedAllowedChannels <- case (requestChannel, length channels) of (Just rc, 0) -> pure [rc] @@ -56,32 +58,30 @@ validateClaims requestChannel secret jwtToken time = runExceptT $ do (Nothing, _) -> pure channels validChannels <- if null requestedAllowedChannels then throwError "No allowed channels" else pure requestedAllowedChannels pure (validChannels, mode, cl') + where + claimAsJSON :: Text -> Claims -> Maybe Text + claimAsJSON name cl = case JSON.lookup (Key.fromText name) cl of + Just (JSON.String s) -> Just s + _ -> Nothing - where - claimAsJSON :: Text -> Claims -> Maybe Text - claimAsJSON name cl = case JSON.lookup (Key.fromText name) cl of - Just (JSON.String s) -> Just s - _ -> Nothing - - claimAsJSONList :: Text -> Claims -> Maybe [Text] - claimAsJSONList name cl = case JSON.lookup (Key.fromText name) cl of - Just channelsJson -> - case JSON.fromJSON channelsJson :: JSON.Result [Text] of - JSON.Success channelsList -> Just channelsList - _ -> Nothing - Nothing -> Nothing + claimAsJSONList :: Text -> Claims -> Maybe [Text] + claimAsJSONList name cl = case JSON.lookup (Key.fromText name) cl of + Just channelsJson -> + case JSON.fromJSON channelsJson :: JSON.Result [Text] of + JSON.Success channelsList -> Just channelsList + _ -> Nothing + Nothing -> Nothing -{-| - Possible situations encountered with client JWTs --} -data JWTAttempt = JWTInvalid JWTError - | JWTClaims (JSON.KeyMap JSON.Value) - deriving Eq +-- | +-- Possible situations encountered with client JWTs +data JWTAttempt + = JWTInvalid JWTError + | JWTClaims (JSON.KeyMap JSON.Value) + deriving (Eq) -{-| - Receives the JWT secret (from config) and a JWT and returns a map - of JWT claims. --} +-- | +-- Receives the JWT secret (from config) and a JWT and returns a map +-- of JWT claims. jwtClaims :: UTCTime -> JWK -> LByteString -> IO JWTAttempt jwtClaims _ _ "" = return $ JWTClaims JSON.empty jwtClaims time jwk' payload = do @@ -90,32 +90,30 @@ jwtClaims time jwk' payload = do jwt <- decodeCompact payload verifyClaimsAt config jwk' time jwt return $ case eJwt of - Left e -> JWTInvalid e + Left e -> JWTInvalid e Right jwt -> JWTClaims . claims2map $ jwt -{-| - Internal helper used to turn JWT ClaimSet into something - easier to work with --} +-- | +-- Internal helper used to turn JWT ClaimSet into something +-- easier to work with claims2map :: ClaimsSet -> JSON.KeyMap JSON.Value claims2map = val2map . JSON.toJSON - where - val2map (JSON.Object o) = o - val2map _ = JSON.empty + where + val2map (JSON.Object o) = o + val2map _ = JSON.empty -{-| - Internal helper to generate HMAC-SHA256. When the jwt key in the - config file is a simple string rather than a JWK object, we'll - apply this function to it. --} +-- | +-- Internal helper to generate HMAC-SHA256. When the jwt key in the +-- config file is a simple string rather than a JWK object, we'll +-- apply this function to it. hs256jwk :: ByteString -> JWK hs256jwk key = fromKeyMaterial km & jwkUse ?~ Sig & jwkAlg ?~ JWSAlg HS256 - where - km = OctKeyMaterial (OctKeyParameters (JOSE.Types.Base64Octets key)) + where + km = OctKeyMaterial (OctKeyParameters (JOSE.Types.Base64Octets key)) parseJWK :: ByteString -> JWK parseJWK str = - fromMaybe (hs256jwk str) (JSON.decode (toS str) :: Maybe JWK) + fromMaybe (hs256jwk str) (JSON.decode (fromStrict str) :: Maybe JWK) diff --git a/src/PostgresWebsockets/Config.hs b/src/PostgresWebsockets/Config.hs index 3436845..7215e56 100644 --- a/src/PostgresWebsockets/Config.hs +++ b/src/PostgresWebsockets/Config.hs @@ -14,6 +14,7 @@ module PostgresWebsockets.Config ) where +import APrelude import qualified Data.ByteString as BS import qualified Data.ByteString.Base64 as B64 import Data.String (IsString (..)) @@ -22,8 +23,6 @@ import Data.Version (versionBranch) import Env import Network.Wai.Handler.Warp import Paths_postgres_websockets (version) -import Protolude hiding (intercalate, optional, replace, toS, (<>)) -import Protolude.Conv -- | Config file settings for the server data AppConfig = AppConfig @@ -45,7 +44,7 @@ data AppConfig = AppConfig -- | User friendly version number prettyVersion :: Text -prettyVersion = intercalate "." $ map show $ versionBranch version +prettyVersion = intercalate "." $ map showText $ versionBranch version -- | Load all postgres-websockets config from Environment variables. This can be used to use just the middleware or to feed into warpSettings loadConfig :: IO AppConfig @@ -58,9 +57,9 @@ loadConfig = -- | Given a shutdown handler and an AppConfig builds a Warp Settings to start a stand-alone server warpSettings :: (IO () -> IO ()) -> AppConfig -> Settings warpSettings waitForShutdown AppConfig {..} = - setHost (fromString $ toS configHost) + setHost (fromString $ unpack configHost) . setPort configPort - . setServerName (toS $ "postgres-websockets/" <> prettyVersion) + . setServerName ("postgres-websockets/" <> encodeUtf8 prettyVersion) . setTimeout 3600 . setInstallShutdownHandler waitForShutdown . setGracefulShutdownTimeout (Just 5) @@ -72,7 +71,8 @@ warpSettings waitForShutdown AppConfig {..} = readOptions :: IO AppConfig readOptions = Env.parse (header "You need to configure some environment variables to start the service.") $ - AppConfig <$> var (str <=< nonempty) "PGWS_DB_URI" (help "String to connect to PostgreSQL") + AppConfig + <$> var (str <=< nonempty) "PGWS_DB_URI" (help "String to connect to PostgreSQL") <*> optional (var str "PGWS_ROOT_PATH" (help "Root path to serve static files, unset to disable.")) <*> var str "PGWS_HOST" (def "*4" <> helpDef show <> help "Address the server will listen for websocket connections") <*> var auto "PGWS_PORT" (def 3000 <> helpDef show <> help "Port the server will listen for websocket connections") @@ -96,7 +96,7 @@ loadDatabaseURIFile :: AppConfig -> IO AppConfig loadDatabaseURIFile conf@AppConfig {..} = case stripPrefix "@" configDatabase of Nothing -> pure conf - Just filename -> setDatabase . strip <$> readFile (toS filename) + Just filename -> setDatabase . strip . pack <$> readFile (unpack filename) where setDatabase uri = conf {configDatabase = uri} @@ -112,7 +112,7 @@ loadSecretFile conf = extractAndTransform secret transformString isB64 =<< case stripPrefix "@" s of Nothing -> return . encodeUtf8 $ s - Just filename -> chomp <$> BS.readFile (toS filename) + Just filename -> chomp <$> BS.readFile (unpack filename) where chomp bs = fromMaybe bs (BS.stripSuffix "\n" bs) diff --git a/src/PostgresWebsockets/Context.hs b/src/PostgresWebsockets/Context.hs index a944f14..86e25ca 100644 --- a/src/PostgresWebsockets/Context.hs +++ b/src/PostgresWebsockets/Context.hs @@ -7,6 +7,7 @@ module PostgresWebsockets.Context ) where +import APrelude import Control.AutoUpdate ( defaultUpdateSettings, mkAutoUpdate, @@ -18,8 +19,6 @@ import qualified Hasql.Pool.Config as P import PostgresWebsockets.Broadcast (Multiplexer) import PostgresWebsockets.Config (AppConfig (..)) import PostgresWebsockets.HasqlBroadcast (newHasqlBroadcaster) -import Protolude hiding (toS) -import Protolude.Conv data Context = Context { ctxConfig :: AppConfig, @@ -33,15 +32,15 @@ mkContext :: AppConfig -> IO () -> IO Context mkContext conf@AppConfig {..} shutdownServer = do Context conf <$> P.acquire config - <*> newHasqlBroadcaster shutdown (toS configListenChannel) configRetries configReconnectInterval pgSettings + <*> newHasqlBroadcaster shutdown configListenChannel configRetries configReconnectInterval pgSettings <*> mkGetTime where config = P.settings [P.staticConnectionSettings pgSettings] shutdown = maybe shutdownServer - (const $ putText "Producer thread is dead") + (const $ putStrLn "Producer thread is dead") configReconnectInterval mkGetTime :: IO (IO UTCTime) mkGetTime = mkAutoUpdate defaultUpdateSettings {updateAction = getCurrentTime} - pgSettings = toS configDatabase + pgSettings = encodeUtf8 configDatabase diff --git a/src/PostgresWebsockets/HasqlBroadcast.hs b/src/PostgresWebsockets/HasqlBroadcast.hs index c622aff..9bb39a2 100644 --- a/src/PostgresWebsockets/HasqlBroadcast.hs +++ b/src/PostgresWebsockets/HasqlBroadcast.hs @@ -15,11 +15,11 @@ module PostgresWebsockets.HasqlBroadcast ) where +import APrelude import Control.Retry (RetryStatus (..), capDelay, exponentialBackoff, retrying) import Data.Aeson (Value (..), decode) -import qualified Data.Aeson.KeyMap as JSON import qualified Data.Aeson.Key as Key - +import qualified Data.Aeson.KeyMap as JSON import Data.Either.Combinators (mapBoth) import Data.Function (id) import GHC.Show @@ -30,8 +30,6 @@ import Hasql.Notifications import qualified Hasql.Session as H import qualified Hasql.Statement as H import PostgresWebsockets.Broadcast -import Protolude hiding (putErrLn, show, toS) -import Protolude.Conv -- | Returns a multiplexer from a connection URI, keeps trying to connect in case there is any error. -- This function also spawns a thread that keeps relaying the messages from the database to the multiplexer's listeners @@ -44,7 +42,7 @@ newHasqlBroadcaster onConnectionFailure ch maxRetries checkInterval = newHasqlBr -- This function also spawns a thread that keeps relaying the messages from the database to the multiplexer's listeners newHasqlBroadcasterOrError :: IO () -> Text -> ByteString -> IO (Either ByteString Multiplexer) newHasqlBroadcasterOrError onConnectionFailure ch = - acquire >=> (sequence . mapBoth (toSL . show) (newHasqlBroadcasterForConnection . return)) + acquire >=> (sequence . mapBoth showBS (newHasqlBroadcasterForConnection . return)) where newHasqlBroadcasterForConnection = newHasqlBroadcasterForChannel onConnectionFailure ch Nothing @@ -60,7 +58,7 @@ tryUntilConnected maxRetries = shouldRetry RetryStatus {..} con = case con of Left err -> do - putErrLn $ "Error connecting notification listener to database: " <> (toS . show) err + putErrLn $ "Error connecting notification listener to database: " <> showText err pure $ rsIterNumber < maxRetries - 1 _ -> return False @@ -94,16 +92,16 @@ newHasqlBroadcasterForChannel onConnectionFailure ch checkInterval getCon = do return multi where toMsg :: Text -> Text -> Message - toMsg c m = case decode (toS m) of + toMsg c m = case decode (fromStrict $ encodeUtf8 m) of Just v -> Message (channelDef c v) m Nothing -> Message c m lookupStringDef :: Text -> Text -> Value -> Text lookupStringDef key d (Object obj) = - case lookupDefault (String $ toS d) key obj of - String s -> toS s - _ -> toS d - lookupStringDef _ d _ = toS d + case lookupDefault (String d) key obj of + String s -> s + _ -> d + lookupStringDef _ d _ = d lookupDefault d key obj = fromMaybe d $ JSON.lookup (Key.fromText key) obj @@ -116,12 +114,9 @@ newHasqlBroadcasterForChannel onConnectionFailure ch checkInterval getCon = do con <- getCon listen con $ toPgIdentifier ch waitForNotifications - (\c m -> atomically $ writeTQueue msgQ $ toMsg (toS c) (toS m)) + (\c m -> atomically $ writeTQueue msgQ $ toMsg (decodeUtf8 c) (decodeUtf8 m)) con -putErrLn :: Text -> IO () -putErrLn = hPutStrLn stderr - isListening :: Connection -> Text -> IO Bool isListening con ch = do resultOrError <- H.run session con @@ -136,4 +131,4 @@ isListeningStatement = where sql = "select exists (select * from pg_stat_activity where datname = current_database() and query ilike $1);" encoder = HE.param $ HE.nonNullable HE.text - decoder = HD.singleRow (HD.column (HD.nonNullable HD.bool)) \ No newline at end of file + decoder = HD.singleRow (HD.column (HD.nonNullable HD.bool)) diff --git a/src/PostgresWebsockets/Middleware.hs b/src/PostgresWebsockets/Middleware.hs index 7c0108b..e0e06d9 100644 --- a/src/PostgresWebsockets/Middleware.hs +++ b/src/PostgresWebsockets/Middleware.hs @@ -1,53 +1,50 @@ -{-| -Module : PostgresWebsockets.Middleware -Description : PostgresWebsockets WAI middleware, add functionality to any WAI application. - -Allow websockets connections that will communicate with the database through LISTEN/NOTIFY channels. --} {-# LANGUAGE DeriveGeneric #-} +-- | +-- Module : PostgresWebsockets.Middleware +-- Description : PostgresWebsockets WAI middleware, add functionality to any WAI application. +-- +-- Allow websockets connections that will communicate with the database through LISTEN/NOTIFY channels. module PostgresWebsockets.Middleware - ( postgresWsMiddleware - ) where + ( postgresWsMiddleware, + ) +where -import Protolude hiding (toS) -import Protolude.Conv -import Data.Time.Clock (UTCTime) -import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds, posixSecondsToUTCTime) +import APrelude import Control.Concurrent.AlarmClock (newAlarmClock, setAlarm) +import qualified Data.Aeson as A +import qualified Data.Aeson.Key as Key +import qualified Data.Aeson.KeyMap as A +import qualified Data.ByteString.Lazy as BL +import qualified Data.Text as T +import Data.Time.Clock (UTCTime) +import Data.Time.Clock.POSIX (posixSecondsToUTCTime, utcTimeToPOSIXSeconds) import qualified Hasql.Notifications as H import qualified Hasql.Pool as H import qualified Network.Wai as Wai import qualified Network.Wai.Handler.WebSockets as WS import qualified Network.WebSockets as WS - -import qualified Data.Aeson as A -import qualified Data.Aeson.KeyMap as A -import qualified Data.Aeson.Key as Key - -import qualified Data.Text as T -import qualified Data.ByteString.Lazy as BL - import PostgresWebsockets.Broadcast (onMessage) -import PostgresWebsockets.Claims ( ConnectionInfo, validateClaims ) -import PostgresWebsockets.Context ( Context(..) ) -import PostgresWebsockets.Config (AppConfig(..)) import qualified PostgresWebsockets.Broadcast as B +import PostgresWebsockets.Claims (ConnectionInfo, validateClaims) +import PostgresWebsockets.Config (AppConfig (..)) +import PostgresWebsockets.Context (Context (..)) - -data Event = - WebsocketMessage +data Event + = WebsocketMessage | ConnectionOpen deriving (Show, Eq, Generic) data Message = Message - { claims :: A.Object - , event :: Event - , payload :: Text - , channel :: Text - } deriving (Show, Eq, Generic) + { claims :: A.Object, + event :: Event, + payload :: Text, + channel :: Text + } + deriving (Show, Eq, Generic) instance A.ToJSON Event + instance A.ToJSON Message -- | Given a secret, a function to fetch the system time, a Hasql Pool and a Multiplexer this will give you a WAI middleware. @@ -62,8 +59,8 @@ jwtExpirationStatusCode = 3001 -- when the websocket is closed a ConnectionClosed Exception is triggered -- this kills all children and frees resources for us wsApp :: Context -> WS.ServerApp -wsApp Context{..} pendingConn = - ctxGetTime >>= validateClaims requestChannel (configJwtSecret ctxConfig) (toS jwtToken) >>= either rejectRequest forkSessions +wsApp Context {..} pendingConn = + ctxGetTime >>= validateClaims requestChannel (configJwtSecret ctxConfig) (fromStrict $ encodeUtf8 jwtToken) >>= either rejectRequest forkSessions where hasRead m = m == ("r" :: Text) || m == ("rw" :: Text) hasWrite m = m == ("w" :: Text) || m == ("rw" :: Text) @@ -71,10 +68,10 @@ wsApp Context{..} pendingConn = rejectRequest :: Text -> IO () rejectRequest msg = do putErrLn $ "Rejecting Request: " <> msg - WS.rejectRequest pendingConn (toS msg) + WS.rejectRequest pendingConn (encodeUtf8 msg) -- the URI has one of the two formats - /:jwt or /:channel/:jwt - pathElements = T.split (== '/') $ T.drop 1 $ (toSL . WS.requestPath) $ WS.pendingRequest pendingConn + pathElements = T.split (== '/') $ T.drop 1 $ (decodeUtf8 . WS.requestPath) $ WS.pendingRequest pendingConn jwtToken = case length pathElements `compare` 1 of GT -> headDef "" $ tailSafe pathElements @@ -85,35 +82,37 @@ wsApp Context{..} pendingConn = _ -> Nothing forkSessions :: ConnectionInfo -> IO () forkSessions (chs, mode, validClaims) = do - -- We should accept only after verifying JWT - conn <- WS.acceptRequest pendingConn - -- Fork a pinging thread to ensure browser connections stay alive - WS.withPingThread conn 30 (pure ()) $ do - case A.lookup "exp" validClaims of - Just (A.Number expClaim) -> do - connectionExpirer <- newAlarmClock $ const (WS.sendCloseCode conn jwtExpirationStatusCode ("JWT expired" :: ByteString)) - setAlarm connectionExpirer (posixSecondsToUTCTime $ realToFrac expClaim) - Just _ -> pure () - Nothing -> pure () - - let sendNotification msg channel = sendMessageWithTimestamp $ websocketMessageForChannel msg channel - sendMessageToDatabase = sendToDatabase ctxPool (configListenChannel ctxConfig) - sendMessageWithTimestamp = timestampMessage ctxGetTime >=> sendMessageToDatabase - websocketMessageForChannel = Message validClaims WebsocketMessage - connectionOpenMessage = Message validClaims ConnectionOpen - - case configMetaChannel ctxConfig of - Nothing -> pure () - Just ch -> sendMessageWithTimestamp $ connectionOpenMessage (toS $ T.intercalate "," chs) ch - - when (hasRead mode) $ - forM_ chs $ flip (onMessage ctxMulti) $ WS.sendTextData conn . B.payload - - when (hasWrite mode) $ - notifySession conn sendNotification chs - - waitForever <- newEmptyMVar - void $ takeMVar waitForever + -- We should accept only after verifying JWT + conn <- WS.acceptRequest pendingConn + -- Fork a pinging thread to ensure browser connections stay alive + WS.withPingThread conn 30 (pure ()) $ do + case A.lookup "exp" validClaims of + Just (A.Number expClaim) -> do + connectionExpirer <- newAlarmClock $ const (WS.sendCloseCode conn jwtExpirationStatusCode ("JWT expired" :: ByteString)) + setAlarm connectionExpirer (posixSecondsToUTCTime $ realToFrac expClaim) + Just _ -> pure () + Nothing -> pure () + + let sendNotification msg channel = sendMessageWithTimestamp $ websocketMessageForChannel msg channel + sendMessageToDatabase = sendToDatabase ctxPool (configListenChannel ctxConfig) + sendMessageWithTimestamp = timestampMessage ctxGetTime >=> sendMessageToDatabase + websocketMessageForChannel = Message validClaims WebsocketMessage + connectionOpenMessage = Message validClaims ConnectionOpen + + case configMetaChannel ctxConfig of + Nothing -> pure () + Just ch -> sendMessageWithTimestamp $ connectionOpenMessage (T.intercalate "," chs) ch + + when (hasRead mode) $ + forM_ chs $ + flip (onMessage ctxMulti) $ + WS.sendTextData conn . B.payload + + when (hasWrite mode) $ + notifySession conn sendNotification chs + + waitForever <- newEmptyMVar + void $ takeMVar waitForever -- Having both channel and claims as parameters seem redundant -- But it allows the function to ignore the claims structure and the source @@ -124,16 +123,16 @@ notifySession wsCon sendToChannel chs = where relayData = do msg <- WS.receiveData wsCon - forM_ chs (sendToChannel msg . toS) + forM_ chs (sendToChannel msg) sendToDatabase :: H.Pool -> Text -> Message -> IO () sendToDatabase pool dbChannel = notify . jsonMsg where - notify = void . H.notifyPool pool dbChannel . toS + notify = void . H.notifyPool pool dbChannel . decodeUtf8 jsonMsg = BL.toStrict . A.encode timestampMessage :: IO UTCTime -> Message -> IO Message -timestampMessage getTime msg@Message{..} = do +timestampMessage getTime msg@Message {..} = do time <- utcTimeToPOSIXSeconds <$> getTime - return $ msg{ claims = A.insert (Key.fromText "message_delivered_at") (A.Number $ realToFrac time) claims} + return $ msg {claims = A.insert (Key.fromText "message_delivered_at") (A.Number $ realToFrac time) claims} diff --git a/src/PostgresWebsockets/Server.hs b/src/PostgresWebsockets/Server.hs index e579777..1cf33ba 100644 --- a/src/PostgresWebsockets/Server.hs +++ b/src/PostgresWebsockets/Server.hs @@ -6,6 +6,7 @@ module PostgresWebsockets.Server ) where +import APrelude import Network.HTTP.Types (status200) import Network.Wai (Application, responseLBS) import Network.Wai.Application.Static (defaultFileServerSettings, staticApp) @@ -15,13 +16,12 @@ import Network.Wai.Middleware.RequestLogger (logStdout) import PostgresWebsockets.Config (AppConfig (..), warpSettings) import PostgresWebsockets.Context (mkContext) import PostgresWebsockets.Middleware (postgresWsMiddleware) -import Protolude -- | Start a stand-alone warp server using the parameters from AppConfig and a opening a database connection pool. serve :: AppConfig -> IO () serve conf@AppConfig {..} = do shutdownSignal <- newEmptyMVar - putStrLn $ ("Listening on port " :: Text) <> show configPort + putStrLn $ "Listening on port " <> show configPort let shutdown = putErrLn ("Broadcaster connection is dead" :: Text) >> putMVar shutdownSignal () ctx <- mkContext conf shutdown @@ -31,13 +31,13 @@ serve conf@AppConfig {..} = do app = postgresWsMiddleware ctx $ logStdout $ maybe dummyApp staticApp' configPath case (configCertificateFile, configKeyFile) of - (Just certificate, Just key) -> runTLS (tlsSettings (toS certificate) (toS key)) appSettings app + (Just certificate, Just key) -> runTLS (tlsSettings (unpack certificate) (unpack key)) appSettings app _ -> runSettings appSettings app die "Shutting down server..." where staticApp' :: Text -> Application - staticApp' = staticApp . defaultFileServerSettings . toS + staticApp' = staticApp . defaultFileServerSettings . unpack dummyApp :: Application dummyApp _ respond = respond $ responseLBS status200 [("Content-Type", "text/plain")] "Hello, Web!" diff --git a/test/BroadcastSpec.hs b/test/BroadcastSpec.hs index 1fa376a..393a3c6 100644 --- a/test/BroadcastSpec.hs +++ b/test/BroadcastSpec.hs @@ -1,11 +1,9 @@ module BroadcastSpec (spec) where -import Protolude +import APrelude import Control.Concurrent.STM.TQueue - -import Test.Hspec - import PostgresWebsockets.Broadcast +import Test.Hspec spec :: Spec spec = do @@ -13,18 +11,27 @@ spec = do it "opens a separate thread for a producer function" $ do output <- newTQueueIO :: IO (TQueue ThreadId) - void $ liftIO $ newMultiplexer (\_-> do - tid <- myThreadId - atomically $ writeTQueue output tid - ) (\_ -> return ()) + void $ + liftIO $ + newMultiplexer + ( \_ -> do + tid <- myThreadId + atomically $ writeTQueue output tid + ) + (\_ -> return ()) outMsg <- atomically $ readTQueue output myThreadId `shouldNotReturn` outMsg describe "relayMessages" $ it "relays a single message from producer to 1 listener on 1 test channel" $ do output <- newTQueueIO :: IO (TQueue Message) - multi <- liftIO $ newMultiplexer (\msgs-> - atomically $ writeTQueue msgs (Message "test" "payload")) (\_ -> return ()) + multi <- + liftIO $ + newMultiplexer + ( \msgs -> + atomically $ writeTQueue msgs (Message "test" "payload") + ) + (\_ -> return ()) void $ onMessage multi "test" $ atomically . writeTQueue output liftIO $ relayMessages multi diff --git a/test/ClaimsSpec.hs b/test/ClaimsSpec.hs index 358a7ca..9d5cdd9 100644 --- a/test/ClaimsSpec.hs +++ b/test/ClaimsSpec.hs @@ -1,12 +1,12 @@ module ClaimsSpec (spec) where -import Protolude - -import Test.Hspec -import Data.Aeson (Value (..), toJSON) +import APrelude +import Data.Aeson (Value (..), toJSON) import qualified Data.Aeson.KeyMap as JSON import Data.Time.Clock -import PostgresWebsockets.Claims +import PostgresWebsockets.Claims +import Test.Hspec +import Prelude secret :: ByteString secret = "reallyreallyreallyreallyverysafe" @@ -16,39 +16,60 @@ spec = describe "validate claims" $ do it "should invalidate an expired token" $ do time <- getCurrentTime - validateClaims Nothing secret - "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoiciIsImNoYW5uZWwiOiJ0ZXN0IiwiZXhwIjoxfQ.4rDYiMZFR2WHB7Eq4HMdvDP_BQZVtHIfyJgy0NshbHY" time + validateClaims + Nothing + secret + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoiciIsImNoYW5uZWwiOiJ0ZXN0IiwiZXhwIjoxfQ.4rDYiMZFR2WHB7Eq4HMdvDP_BQZVtHIfyJgy0NshbHY" + time `shouldReturn` Left "Token expired" it "request any channel from a token that does not have channels or channel claims should succeed" $ do time <- getCurrentTime - validateClaims (Just "test") secret - "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoiciJ9.jL5SsRFegNUlbBm8_okhHSujqLcKKZdDglfdqNl1_rY" time - `shouldReturn` Right (["test"], "r", JSON.fromList[("mode",String "r")]) + validateClaims + (Just "test") + secret + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoiciJ9.jL5SsRFegNUlbBm8_okhHSujqLcKKZdDglfdqNl1_rY" + time + `shouldReturn` Right (["test"], "r", JSON.fromList [("mode", String "r")]) it "requesting a channel that is set by and old style channel claim should work" $ do time <- getCurrentTime - validateClaims (Just "test") secret - "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoiciIsImNoYW5uZWwiOiJ0ZXN0In0.1d4s-at2kWj8OSabHZHTbNh1dENF7NWy_r0ED3Rwf58" time - `shouldReturn` Right (["test"], "r", JSON.fromList[("mode",String "r"),("channel",String "test")]) + validateClaims + (Just "test") + secret + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoiciIsImNoYW5uZWwiOiJ0ZXN0In0.1d4s-at2kWj8OSabHZHTbNh1dENF7NWy_r0ED3Rwf58" + time + `shouldReturn` Right (["test"], "r", JSON.fromList [("mode", String "r"), ("channel", String "test")]) it "no requesting channel should return all channels in the token" $ do time <- getCurrentTime - validateClaims Nothing secret - "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoiciIsImNoYW5uZWxzIjpbInRlc3QiLCJhbm90aGVyIHRlc3QiXX0.b9N8J8tPOPIxxFj5WJ7sWrmcL8ib63i8eirsRZTM9N0" time - `shouldReturn` Right (["test", "another test"], "r", JSON.fromList[("mode",String "r"),("channels", toJSON["test"::Text, "another test"::Text] ) ]) + validateClaims + Nothing + secret + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoiciIsImNoYW5uZWxzIjpbInRlc3QiLCJhbm90aGVyIHRlc3QiXX0.b9N8J8tPOPIxxFj5WJ7sWrmcL8ib63i8eirsRZTM9N0" + time + `shouldReturn` Right (["test", "another test"], "r", JSON.fromList [("mode", String "r"), ("channels", toJSON ["test" :: Text, "another test" :: Text])]) it "requesting a channel from the channels claim shoud return only the requested channel" $ do time <- getCurrentTime - validateClaims (Just "test") secret - "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoiciIsImNoYW5uZWxzIjpbInRlc3QiLCJ0ZXN0MiJdfQ.MumdJ5FpFX4Z6SJD3qsygVF0r9vqxfqhj5J30O32N0k" time - `shouldReturn` Right (["test"], "r", JSON.fromList[("mode",String "r"),("channels", toJSON ["test"::Text, "test2"] )]) + validateClaims + (Just "test") + secret + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoiciIsImNoYW5uZWxzIjpbInRlc3QiLCJ0ZXN0MiJdfQ.MumdJ5FpFX4Z6SJD3qsygVF0r9vqxfqhj5J30O32N0k" + time + `shouldReturn` Right (["test"], "r", JSON.fromList [("mode", String "r"), ("channels", toJSON ["test" :: Text, "test2"])]) it "requesting a channel not from the channels claim shoud error" $ do time <- getCurrentTime - validateClaims (Just "notAllowed") secret - "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoiciIsImNoYW5uZWxzIjpbInRlc3QiLCJ0ZXN0MiJdfQ.MumdJ5FpFX4Z6SJD3qsygVF0r9vqxfqhj5J30O32N0k" time - `shouldReturn` Left "No allowed channels" + validateClaims + (Just "notAllowed") + secret + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJtb2RlIjoiciIsImNoYW5uZWxzIjpbInRlc3QiLCJ0ZXN0MiJdfQ.MumdJ5FpFX4Z6SJD3qsygVF0r9vqxfqhj5J30O32N0k" + time + `shouldReturn` Left "No allowed channels" it "requesting a channel with no mode fails" $ do time <- getCurrentTime - validateClaims (Just "test") secret - "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJjaGFubmVscyI6WyJ0ZXN0IiwidGVzdDIiXX0.akC1PEYk2DEZtLP2XjC6qXOGZJejmPx49qv-VeEtQYQ" time + validateClaims + (Just "test") + secret + "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJjaGFubmVscyI6WyJ0ZXN0IiwidGVzdDIiXX0.akC1PEYk2DEZtLP2XjC6qXOGZJejmPx49qv-VeEtQYQ" + time `shouldReturn` Left "Missing mode" diff --git a/test/HasqlBroadcastSpec.hs b/test/HasqlBroadcastSpec.hs index 1cbdf67..8d00156 100644 --- a/test/HasqlBroadcastSpec.hs +++ b/test/HasqlBroadcastSpec.hs @@ -1,25 +1,23 @@ module HasqlBroadcastSpec (spec) where -import Protolude - -import Data.Function (id) -import Test.Hspec +import APrelude +import Hasql.Notifications import PostgresWebsockets.Broadcast import PostgresWebsockets.HasqlBroadcast -import Hasql.Notifications +import Test.Hspec spec :: Spec spec = describe "newHasqlBroadcaster" $ do - let newConnection connStr = - either (panic . show) id - <$> acquire connStr + let newConnection connStr = + either (panic . showText) id + <$> acquire connStr - it "relay messages sent to the appropriate database channel" $ do - multi <- either (panic .show) id <$> newHasqlBroadcasterOrError (pure ()) "postgres-websockets" "postgres://postgres:roottoor@localhost:5432/postgres_ws_test" - msg <- liftIO newEmptyMVar - onMessage multi "test" $ putMVar msg + it "relay messages sent to the appropriate database channel" $ do + multi <- either (panic . showText) id <$> newHasqlBroadcasterOrError (pure ()) "postgres-websockets" "postgres://postgres:roottoor@localhost:5432/postgres_ws_test" + msg <- liftIO newEmptyMVar + onMessage multi "test" $ putMVar msg - con <- newConnection "postgres://postgres:roottoor@localhost:5432/postgres_ws_test" - void $ notify con (toPgIdentifier "postgres-websockets") "{\"channel\": \"test\", \"payload\": \"hello there\"}" + con <- newConnection "postgres://postgres:roottoor@localhost:5432/postgres_ws_test" + void $ notify con (toPgIdentifier "postgres-websockets") "{\"channel\": \"test\", \"payload\": \"hello there\"}" - readMVar msg `shouldReturn` Message "test" "{\"channel\": \"test\", \"payload\": \"hello there\"}" + readMVar msg `shouldReturn` Message "test" "{\"channel\": \"test\", \"payload\": \"hello there\"}" diff --git a/test/ServerSpec.hs b/test/ServerSpec.hs index 626921e..6666043 100644 --- a/test/ServerSpec.hs +++ b/test/ServerSpec.hs @@ -1,12 +1,12 @@ module ServerSpec (spec) where +import APrelude import Control.Lens import Data.Aeson.Lens import Network.Socket (withSocketsDo) import qualified Network.WebSockets as WS import PostgresWebsockets import PostgresWebsockets.Config -import Protolude import Test.Hspec testServerConfig :: AppConfig @@ -46,7 +46,7 @@ sendWsData uri msg = WS.runClient "127.0.0.1" (configPort testServerConfig) - (toS uri) + (unpack uri) (`WS.sendTextData` msg) testChannel :: Text @@ -67,7 +67,7 @@ waitForWsData uri = do WS.runClient "127.0.0.1" (configPort testServerConfig) - (toS uri) + (unpack uri) ( \c -> do m <- WS.receiveData c putMVar msg m @@ -84,7 +84,7 @@ waitForMultipleWsData messageCount uri = do WS.runClient "127.0.0.1" (configPort testServerConfig) - (toS uri) + (unpack uri) ( \c -> do m <- replicateM messageCount (WS.receiveData c) putMVar msg m