close websocket connection on JWT expiry (fix #578) (#2156)

This commit is contained in:
Rakesh Emmadi
2019-05-14 11:54:46 +05:30
committed by Vamshi Surabhi
parent ee783e142e
commit c6f40df6d5
8 changed files with 160 additions and 75 deletions

View File

@@ -16,6 +16,7 @@ import qualified Data.CaseInsensitive as CI
import qualified Data.HashMap.Strict as Map
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import qualified Data.Time.Clock as TC
import qualified Language.GraphQL.Draft.Syntax as G
import qualified ListT
import qualified Network.HTTP.Client as H
@@ -37,10 +38,10 @@ import qualified Hasura.Logging as L
import Hasura.Prelude
import Hasura.RQL.Types
import Hasura.RQL.Types.Error (Code (StartFailed))
import Hasura.Server.Auth (AuthMode,
getUserInfo)
import Hasura.Server.Auth (AuthMode, getUserInfoWithExpTime)
import Hasura.Server.Cors
import Hasura.Server.Utils (bsToTxt)
import Hasura.Server.Utils (bsToTxt,
diffTimeToMicro)
type OperationMap
= STMMap.Map OperationId (LQ.LiveQueryId, Maybe OperationName)
@@ -59,12 +60,13 @@ data WSConnState
= CSNotInitialised !WsHeaders
| CSInitError Text
-- headers from the client (in conn params) to forward to the remote schema
| CSInitialised UserInfo [H.Header]
-- and JWT expiry time if any
| CSInitialised UserInfo (Maybe TC.UTCTime) [H.Header]
data WSConnData
= WSConnData
-- the role and headers are set only on connection_init message
{ _wscUser :: !(IORef.IORef WSConnState)
{ _wscUser :: !(STM.TVar WSConnState)
-- we only care about subscriptions,
-- the other operations (query/mutations)
-- are not tracked here
@@ -109,6 +111,7 @@ data WSLog
= WSLog
{ _wslWebsocketId :: !WS.WSId
, _wslUser :: !(Maybe UserVars)
, _wslJwtExpiry :: !(Maybe TC.UTCTime)
, _wslEvent :: !WSEvent
, _wslMsg :: !(Maybe Text)
} deriving (Show, Eq)
@@ -145,18 +148,30 @@ onConn (L.Logger logger) corsPolicy wsId requestHead = do
sendMsg wsConn SMConnKeepAlive
threadDelay $ 5 * 1000 * 1000
jwtExpiryHandler wsConn = do
expTime <- STM.atomically $ do
connState <- STM.readTVar $ (_wscUser . WS.getData) wsConn
case connState of
CSNotInitialised _ -> STM.retry
CSInitError _ -> STM.retry
CSInitialised _ expTimeM _ ->
maybe STM.retry return expTimeM
currTime <- TC.getCurrentTime
threadDelay $ diffTimeToMicro $ TC.diffUTCTime expTime currTime
accept hdrs errType = do
logger $ WSLog wsId Nothing EAccepted Nothing
logger $ WSLog wsId Nothing Nothing EAccepted Nothing
connData <- WSConnData
<$> IORef.newIORef (CSNotInitialised hdrs)
<$> STM.newTVarIO (CSNotInitialised hdrs)
<*> STMMap.newIO
<*> pure errType
let acceptRequest = WS.defaultAcceptRequest
{ WS.acceptSubprotocol = Just "graphql-ws"}
return $ Right (connData, acceptRequest, Just keepAliveAction)
return $ Right $ WS.AcceptWith connData acceptRequest
(Just keepAliveAction) (Just jwtExpiryHandler)
reject qErr = do
logger $ WSLog wsId Nothing (ERejected qErr) Nothing
logger $ WSLog wsId Nothing Nothing (ERejected qErr) Nothing
return $ Left $ WS.RejectRequest
(H.statusCode $ qeStatus qErr)
(H.statusMessage $ qeStatus qErr) []
@@ -178,7 +193,7 @@ onConn (L.Logger logger) corsPolicy wsId requestHead = do
if readCookie
then return reqHdrs
else do
liftIO $ logger $ WSLog wsId Nothing EAccepted (Just corsNote)
liftIO $ logger $ WSLog wsId Nothing Nothing EAccepted (Just corsNote)
return $ filter (\h -> fst h /= "Cookie") reqHdrs
CCAllowedOrigins ds
-- if the origin is in our cors domains, no error
@@ -212,9 +227,9 @@ onStart serverEnv wsConn (StartMsg opId q) msgRaw = catchAndIgnore $ do
when (isJust opM) $ withComplete $ sendStartErr $
"an operation already exists with this id: " <> unOperationId opId
userInfoM <- liftIO $ IORef.readIORef userInfoR
userInfoM <- liftIO $ STM.readTVarIO userInfoR
(userInfo, reqHdrs) <- case userInfoM of
CSInitialised userInfo reqHdrs -> return (userInfo, reqHdrs)
CSInitialised userInfo _ reqHdrs -> return (userInfo, reqHdrs)
CSInitError initErr -> do
let e = "cannot start as connection_init failed with : " <> initErr
withComplete $ sendStartErr e
@@ -366,11 +381,13 @@ logWSEvent
:: (MonadIO m)
=> L.Logger -> WSConn -> WSEvent -> m ()
logWSEvent (L.Logger logger) wsConn wsEv = do
userInfoME <- liftIO $ IORef.readIORef userInfoR
let userInfoM = case userInfoME of
CSInitialised userInfo _ -> return $ userVars userInfo
_ -> Nothing
liftIO $ logger $ WSLog wsId userInfoM wsEv Nothing
userInfoME <- liftIO $ STM.readTVarIO userInfoR
let (userVarsM, jwtExpM) = case userInfoME of
CSInitialised userInfo jwtM _ -> ( Just $ userVars userInfo
, jwtM
)
_ -> (Nothing, Nothing)
liftIO $ logger $ WSLog wsId userVarsM jwtExpM wsEv Nothing
where
WSConnData userInfoR _ _ = WS.getData wsConn
wsId = WS.getWSId wsConn
@@ -379,18 +396,18 @@ onConnInit
:: (MonadIO m)
=> L.Logger -> H.Manager -> WSConn -> AuthMode -> Maybe ConnParams -> m ()
onConnInit logger manager wsConn authMode connParamsM = do
headers <- mkHeaders <$> liftIO (IORef.readIORef (_wscUser $ WS.getData wsConn))
res <- runExceptT $ getUserInfo logger manager headers authMode
headers <- mkHeaders <$> liftIO (STM.readTVarIO (_wscUser $ WS.getData wsConn))
res <- runExceptT $ getUserInfoWithExpTime logger manager headers authMode
case res of
Left e -> do
liftIO $ IORef.writeIORef (_wscUser $ WS.getData wsConn) $
liftIO $ STM.atomically $ STM.writeTVar (_wscUser $ WS.getData wsConn) $
CSInitError $ qeError e
let connErr = ConnErrMsg $ qeError e
logWSEvent logger wsConn $ EConnErr connErr
sendMsg wsConn $ SMConnErr connErr
Right userInfo -> do
liftIO $ IORef.writeIORef (_wscUser $ WS.getData wsConn) $
CSInitialised userInfo paramHeaders
Right (userInfo, expTimeM) -> do
liftIO $ STM.atomically $ STM.writeTVar (_wscUser $ WS.getData wsConn) $
CSInitialised userInfo expTimeM paramHeaders
sendMsg wsConn SMConnAck
-- TODO: send it periodically? Why doesn't apollo's protocol use
-- ping/pong frames of websocket spec?
@@ -411,10 +428,9 @@ onConnInit logger manager wsConn authMode connParamsM = do
onClose
:: L.Logger
-> LQ.LiveQueriesState
-> WS.ConnectionException
-> WSConn
-> IO ()
onClose logger lqMap _ wsConn = do
onClose logger lqMap wsConn = do
logWSEvent logger wsConn EClosed
operations <- STM.atomically $ ListT.toList $ STMMap.listT opMap
void $ A.forConcurrently operations $ \(_, (lqId, _)) ->

View File

@@ -9,6 +9,7 @@ module Hasura.GraphQL.Transport.WebSocket.Server
, closeConn
, sendMsg
, AcceptWith(..)
, OnConnH
, OnCloseH
, OnMessageH
@@ -51,6 +52,7 @@ data WSEvent
| ERejected
| EMessageReceived !TBS.TByteString
| EMessageSent !TBS.TByteString
| EJwtExpired
| ECloseReceived
| ECloseSent !TBS.TByteString
| EClosed
@@ -118,10 +120,17 @@ closeAll (WSServer (L.Logger writeLog) connMap) msg = do
return conns
void $ A.mapConcurrently (flip closeConn msg . snd) conns
type AcceptWith a = (a, WS.AcceptRequest, Maybe (WSConn a -> IO ()))
data AcceptWith a
= AcceptWith
{ _awData :: !a
, _awReq :: !WS.AcceptRequest
, _awKeepAlive :: !(Maybe (WSConn a -> IO ()))
, _awOnJwtExpiry :: !(Maybe (WSConn a -> IO ()))
}
type OnConnH a = WSId -> WS.RequestHead ->
IO (Either WS.RejectRequest (AcceptWith a))
type OnCloseH a = WS.ConnectionException -> WSConn a -> IO ()
type OnCloseH a = WSConn a -> IO ()
type OnMessageH a = WSConn a -> BL.ByteString -> IO ()
data WSHandlers a
@@ -149,7 +158,7 @@ createServerApp (WSServer logger@(L.Logger writeLog) connMap) wsHandlers pending
WS.rejectRequestWith pendingConn rejectRequest
writeLog $ WSLog wsId ERejected
onAccept wsId (a, acceptWithParams, keepAliveM) = do
onAccept wsId (AcceptWith a acceptWithParams keepAliveM onJwtExpiryM) = do
conn <- WS.acceptRequestWith pendingConn acceptWithParams
writeLog $ WSLog wsId EAccepted
@@ -168,19 +177,23 @@ createServerApp (WSServer logger@(L.Logger writeLog) connMap) wsHandlers pending
writeLog $ WSLog wsId $ EMessageSent $ TBS.fromLBS msg
keepAliveRefM <- forM keepAliveM $ \action -> A.async $ action wsConn
onJwtExpiryRefM <- forM onJwtExpiryM $ \action -> A.async $ action wsConn
-- terminates on WS.ConnectionException
let waitOnRefs = maybeToList keepAliveRefM <> [rcvRef, sendRef]
-- terminates on WS.ConnectionException and JWT expiry
let waitOnRefs = catMaybes [keepAliveRefM, onJwtExpiryRefM]
<> [rcvRef, sendRef]
res <- try $ A.waitAnyCancel waitOnRefs
case res of
Left e -> do
Left ( _ :: WS.ConnectionException) -> do
writeLog $ WSLog (_wcConnId wsConn) ECloseReceived
onConnClose e wsConn
-- this will never happen as both the threads never finish
Right _ -> return ()
onConnClose wsConn
-- this will happen when jwt is expired
Right _ -> do
writeLog $ WSLog (_wcConnId wsConn) EJwtExpired
onConnClose wsConn
onConnClose e wsConn = do
onConnClose wsConn = do
STM.atomically $ STMMap.delete (_wcConnId wsConn) connMap
_hOnClose wsHandlers e wsConn
_hOnClose wsHandlers wsConn
writeLog $ WSLog (_wcConnId wsConn) EClosed

View File

@@ -11,9 +11,7 @@ module Hasura.RQL.Types.Permission
, getVarVal
, roleFromVars
, UserInfo
, userRole
, userVars
, UserInfo(..)
, mkUserInfo
, userInfoToList
, adminUserInfo
@@ -25,7 +23,9 @@ module Hasura.RQL.Types.Permission
) where
import Hasura.Prelude
import Hasura.Server.Utils (adminSecretHeader, deprecatedAccessKeyHeader, userRoleHeader)
import Hasura.Server.Utils (adminSecretHeader,
deprecatedAccessKeyHeader,
userRoleHeader)
import Hasura.SQL.Types
import qualified Database.PG.Query as Q

View File

@@ -1,8 +1,6 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE RankNTypes #-}
module Hasura.Server.Auth
( getUserInfo
, getUserInfoWithExpTime
, AuthMode(..)
, mkAuthMode
, AdminSecret (..)
@@ -23,6 +21,7 @@ import Control.Exception (try)
import Control.Lens
import Data.Aeson
import Data.IORef (newIORef)
import Data.Time.Clock (UTCTime)
import qualified Data.Aeson as J
import qualified Data.ByteString.Lazy as BL
@@ -102,11 +101,13 @@ mkAuthMode mAdminSecret mWebHook mJwtSecret mUnAuthRole httpManager lCtx =
(Just _, Just _, Just _) -> throwError
"Fatal Error: Both webhook and JWT mode cannot be enabled at the same time"
where
requiresAdminScrtMsg = " requires --admin-secret (HASURA_GRAPHQL_ADMIN_SECRET) or --access-key (HASURA_GRAPHQL_ACCESS_KEY) to be set"
requiresAdminScrtMsg =
" requires --admin-secret (HASURA_GRAPHQL_ADMIN_SECRET) or "
<> " --access-key (HASURA_GRAPHQL_ACCESS_KEY) to be set"
unAuthRoleNotReqForWebHook =
when (isJust mUnAuthRole) $
throwError $ "Fatal Error: --unauthorized-role (HASURA_GRAPHQL_UNAUTHORIZED_ROLE) is not allowed"
<> " when --auth-hook (HASURA_GRAPHQL_AUTH_HOOK) is set"
when (isJust mUnAuthRole) $ throwError $
"Fatal Error: --unauthorized-role (HASURA_GRAPHQL_UNAUTHORIZED_ROLE) is not allowed"
<> " when --auth-hook (HASURA_GRAPHQL_AUTH_HOOK) is set"
mkJwtCtx
:: ( MonadIO m
@@ -219,7 +220,6 @@ userInfoFromAuthHook logger manager hook reqHeaders = do
, "Cache-Control", "Connection", "DNT"
]
getUserInfo
:: (MonadIO m, MonadError QErr m)
=> L.Logger
@@ -227,17 +227,29 @@ getUserInfo
-> [N.Header]
-> AuthMode
-> m UserInfo
getUserInfo logger manager rawHeaders = \case
getUserInfo l m r a = fst <$> getUserInfoWithExpTime l m r a
AMNoAuth -> return userInfoFromHeaders
getUserInfoWithExpTime
:: (MonadIO m, MonadError QErr m)
=> L.Logger
-> H.Manager
-> [N.Header]
-> AuthMode
-> m (UserInfo, Maybe UTCTime)
getUserInfoWithExpTime logger manager rawHeaders = \case
AMNoAuth -> return (userInfoFromHeaders, Nothing)
AMAdminSecret adminScrt unAuthRole ->
case adminSecretM of
Just givenAdminScrt -> userInfoWhenAdminSecret adminScrt givenAdminScrt
Nothing -> userInfoWhenNoAdminSecret unAuthRole
Just givenAdminScrt ->
withNoExpTime $ userInfoWhenAdminSecret adminScrt givenAdminScrt
Nothing ->
withNoExpTime $ userInfoWhenNoAdminSecret unAuthRole
AMAdminSecretAndHook accKey hook ->
whenAdminSecretAbsent accKey (userInfoFromAuthHook logger manager hook rawHeaders)
whenAdminSecretAbsent accKey $
withNoExpTime $ userInfoFromAuthHook logger manager hook rawHeaders
AMAdminSecretAndJWT accKey jwtSecret unAuthRole ->
whenAdminSecretAbsent accKey (processJwt jwtSecret rawHeaders unAuthRole)
@@ -246,9 +258,10 @@ getUserInfo logger manager rawHeaders = \case
-- when admin secret is absent, run the action to retrieve UserInfo, otherwise
-- adminsecret override
whenAdminSecretAbsent ak action =
maybe action (userInfoWhenAdminSecret ak) $ adminSecretM
maybe action (withNoExpTime . userInfoWhenAdminSecret ak) adminSecretM
adminSecretM= foldl1 (<|>) $ map (flip getVarVal usrVars) [adminSecretHeader, deprecatedAccessKeyHeader]
adminSecretM= foldl1 (<|>) $
map (`getVarVal` usrVars) [adminSecretHeader, deprecatedAccessKeyHeader]
usrVars = mkUserVars $ hdrsToText rawHeaders
@@ -258,9 +271,13 @@ getUserInfo logger manager rawHeaders = \case
Nothing -> mkUserInfo adminRole usrVars
userInfoWhenAdminSecret key reqKey = do
when (reqKey /= getAdminSecret key) $ throw401 $ "invalid " <> adminSecretHeader <> "/" <> deprecatedAccessKeyHeader
when (reqKey /= getAdminSecret key) $ throw401 $
"invalid " <> adminSecretHeader <> "/" <> deprecatedAccessKeyHeader
return userInfoFromHeaders
userInfoWhenNoAdminSecret = \case
Nothing -> throw401 $ adminSecretHeader <> "/" <> deprecatedAccessKeyHeader <> " required, but not found"
Nothing -> throw401 $ adminSecretHeader <> "/"
<> deprecatedAccessKeyHeader <> " required, but not found"
Just role -> return $ mkUserInfo role usrVars
withNoExpTime a = (, Nothing) <$> a

View File

@@ -17,8 +17,8 @@ import Crypto.JWT
import Data.IORef (IORef, modifyIORef, readIORef)
import Data.List (find)
import Data.Time.Clock (NominalDiffTime, diffUTCTime,
getCurrentTime)
import Data.Time.Clock (NominalDiffTime, UTCTime,
diffUTCTime, getCurrentTime)
import Data.Time.Format (defaultTimeLocale, parseTimeM)
import Network.URI (URI)
@@ -28,7 +28,8 @@ import Hasura.Prelude
import Hasura.RQL.Types
import Hasura.Server.Auth.JWT.Internal (parseHmacKey, parseRsaKey)
import Hasura.Server.Auth.JWT.Logging
import Hasura.Server.Utils (bsToTxt, userRoleHeader)
import Hasura.Server.Utils (bsToTxt, diffTimeToMicro,
userRoleHeader)
import qualified Control.Concurrent as C
import qualified Data.Aeson as A
@@ -106,13 +107,12 @@ jwkRefreshCtrl
-> m ()
jwkRefreshCtrl lggr mngr url ref time =
void $ liftIO $ C.forkIO $ do
C.threadDelay $ delay time
C.threadDelay $ diffTimeToMicro time
forever $ do
res <- runExceptT $ updateJwkRef lggr mngr url ref
mTime <- either (const $ return Nothing) return res
C.threadDelay $ maybe (60 * aSecond) delay mTime
C.threadDelay $ maybe (60 * aSecond) diffTimeToMicro mTime
where
delay t = (floor (realToFrac t :: Double) - 10) * aSecond
aSecond = 1000 * 1000
@@ -172,7 +172,7 @@ processJwt
=> JWTCtx
-> HTTP.RequestHeaders
-> Maybe RoleName
-> m UserInfo
-> m (UserInfo, Maybe UTCTime)
processJwt jwtCtx headers mUnAuthRole =
maybe withoutAuthZHeader withAuthZHeader mAuthZHeader
where
@@ -183,7 +183,8 @@ processJwt jwtCtx headers mUnAuthRole =
withoutAuthZHeader = do
unAuthRole <- maybe missingAuthzHeader return mUnAuthRole
return $ mkUserInfo unAuthRole $ mkUserVars $ hdrsToText headers
return $ (, Nothing) $
mkUserInfo unAuthRole $ mkUserVars $ hdrsToText headers
missingAuthzHeader =
throw400 InvalidHeaders "Missing Authorization header in JWT authentication mode"
@@ -194,7 +195,7 @@ processAuthZHeader
=> JWTCtx
-> HTTP.RequestHeaders
-> BLC.ByteString
-> m UserInfo
-> m (UserInfo, Maybe UTCTime)
processAuthZHeader jwtCtx headers authzHeader = do
-- try to parse JWT token from Authorization header
jwt <- parseAuthzHeader
@@ -204,6 +205,7 @@ processAuthZHeader jwtCtx headers authzHeader = do
let claimsNs = fromMaybe defaultClaimNs $ jcxClaimNs jwtCtx
claimsFmt = jcxClaimsFormat jwtCtx
expTimeM = fmap (\(NumericDate t) -> t) $ claims ^. claimExp
-- see if the hasura claims key exist in the claims map
let mHasuraClaims = Map.lookup claimsNs $ claims ^. unregisteredClaims
@@ -227,7 +229,7 @@ processAuthZHeader jwtCtx headers authzHeader = do
-- transform the map of text:aeson-value -> text:text
metadata <- decodeJSON $ A.Object finalClaims
return $ mkUserInfo role $ mkUserVars $ Map.toList metadata
return $ (, expTimeM) $ mkUserInfo role $ mkUserVars $ Map.toList metadata
where
parseAuthzHeader = do

View File

@@ -4,6 +4,7 @@ import qualified Database.PG.Query.Connection as Q
import Data.Aeson
import Data.List.Split
import Data.Time.Clock
import Network.URI
import System.Environment
import System.Exit
@@ -143,3 +144,10 @@ matchRegex regex caseSensitive src =
fmapL :: (a -> a') -> Either a b -> Either a' b
fmapL fn (Left e) = Left (fn e)
fmapL _ (Right x) = pure x
-- diff time to micro seconds
diffTimeToMicro :: NominalDiffTime -> Int
diffTimeToMicro diff =
(floor (realToFrac diff :: Double) - 10) * aSecond
where
aSecond = 1000 * 1000

View File

@@ -1,10 +1,12 @@
from datetime import datetime, timedelta
import math
import json
import time
import yaml
import pytest
import jwt
from test_subscriptions import init_ws_conn
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import serialization
@@ -207,3 +209,29 @@ def gen_rsa_key():
encryption_algorithm=serialization.NoEncryption()
)
return pem
class TestSubscriptionJwtExpiry(object):
def test_jwt_expiry(self, hge_ctx, ws_client):
curr_time = datetime.now()
self.claims = {
'sub': '1234567890',
'name': 'John Doe',
'iat': math.floor(curr_time.timestamp())
}
self.claims['https://hasura.io/jwt/claims'] = mk_claims(hge_ctx.hge_jwt_conf, {
'x-hasura-user-id': '1',
'x-hasura-default-role': 'user',
'x-hasura-allowed-roles': ['user'],
})
exp = curr_time + timedelta(seconds=5)
self.claims['exp'] = round(exp.timestamp())
token = jwt.encode(self.claims, hge_ctx.hge_jwt_key, algorithm='RS512').decode('utf-8')
payload = {
'headers': {
'Authorization': 'Bearer ' + token
}
}
init_ws_conn(hge_ctx, ws_client, payload)
time.sleep(5)
assert ws_client.remote_closed == True, ws_client.remote_closed

View File

@@ -9,14 +9,16 @@ import yaml
Refer: https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init
'''
def init_ws_conn(hge_ctx, ws_client):
payload = {}
if hge_ctx.hge_key is not None:
payload = {
'headers' : {
'X-Hasura-Admin-Secret': hge_ctx.hge_key
def init_ws_conn(hge_ctx, ws_client, payload = None):
if payload is None:
payload = {}
if hge_ctx.hge_key is not None:
payload = {
'headers' : {
'X-Hasura-Admin-Secret': hge_ctx.hge_key
}
}
}
init_msg = {
'type': 'connection_init',
'payload': payload,
@@ -251,4 +253,3 @@ class TestSubscriptionLiveQueries(object):
@classmethod
def dir(cls):
return 'queries/subscriptions/live_queries'