diff --git a/CHANGELOG.md b/CHANGELOG.md index b26ed38673..2f37d21ca3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ This project adheres to [Semantic Versioning](http://semver.org/). - #1536, Add string comparison feature for jwt-role-claim-key - @taimoorzaeem - #3747, Allow `not_null` value for the `is` operator - @taimoorzaeem - #2255, Apply `to_tsvector()` explicitly to the full-text search filtered column (excluding `tsvector` types) - @laurenceisla + - #3802, Add metric `pgrst_jwt_cache_size_bytes` in admin server - @taimoorzaeem ### Fixed diff --git a/docs/references/observability.rst b/docs/references/observability.rst index 04a261ba48..f3a69487ed 100644 --- a/docs/references/observability.rst +++ b/docs/references/observability.rst @@ -169,6 +169,20 @@ pgrst_db_pool_max Max pool connections. +JWT Cache Metric +---------------- + +Related to the :ref:`jwt_caching`. + +pgrst_jwt_cache_size_bytes +~~~~~~~~~~~~~~~~~~~~~~~~~~ + +======== ======= +**Type** Gauge +======== ======= + +The JWT cache size in bytes. 
+ Traces ====== diff --git a/postgrest.cabal b/postgrest.cabal index 0cd6f101f5..d2acd71512 100644 --- a/postgrest.cabal +++ b/postgrest.cabal @@ -49,6 +49,7 @@ library PostgREST.App PostgREST.AppState PostgREST.Auth + PostgREST.Auth.Cache PostgREST.Auth.Types PostgREST.CLI PostgREST.Config @@ -90,7 +91,8 @@ library PostgREST.Response.GucHeader PostgREST.Response.Performance PostgREST.Version - other-modules: Paths_postgrest + other-modules: PostgREST.Internal + Paths_postgrest build-depends: base >= 4.9 && < 4.20 , HTTP >= 4000.3.7 && < 4000.5 , Ranged-sets >= 0.3 && < 0.5 @@ -109,6 +111,7 @@ library , either >= 4.4.1 && < 5.1 , extra >= 1.7.0 && < 2.0 , fuzzyset >= 0.2.4 && < 0.3 + , ghc-heap >= 9.4 && < 9.9 , hasql >= 1.6.1.1 && < 1.7 , hasql-dynamic-statements >= 0.3.1 && < 0.4 , hasql-notifications >= 0.2.2.2 && < 0.2.3 diff --git a/src/PostgREST/App.hs b/src/PostgREST/App.hs index 613c641278..e1408b5be4 100644 --- a/src/PostgREST/App.hs +++ b/src/PostgREST/App.hs @@ -15,7 +15,6 @@ module PostgREST.App , run ) where - import Control.Monad.Except (liftEither) import Data.Either.Combinators (mapLeft) import Data.Maybe (fromJust) diff --git a/src/PostgREST/Auth.hs b/src/PostgREST/Auth.hs index 51a3d81af1..4a83b3d95a 100644 --- a/src/PostgREST/Auth.hs +++ b/src/PostgREST/Auth.hs @@ -24,7 +24,6 @@ import qualified Data.Aeson.KeyMap as KM import qualified Data.Aeson.Types as JSON import qualified Data.ByteString as BS import qualified Data.ByteString.Lazy.Char8 as LBS -import qualified Data.Cache as C import qualified Data.Scientific as Sci import qualified Data.Text as T import qualified Data.Vault.Lazy as Vault @@ -40,12 +39,11 @@ import Data.Either.Combinators (mapLeft) import Data.List (lookup) import Data.Time.Clock (UTCTime, nominalDiffTimeToSeconds) import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds) -import System.Clock (TimeSpec (..)) import System.IO.Unsafe (unsafePerformIO) import System.TimeIt (timeItT) -import PostgREST.AppState (AppState, getConfig, 
getJwtCache, - getTime) +import PostgREST.AppState (AppState, getConfig, getTime) +import PostgREST.Auth.Cache (getJWTFromCache) import PostgREST.Auth.Types (AuthResult (..)) import PostgREST.Config (AppConfig (..), FilterExp (..), JSPath, JSPathExp (..)) @@ -153,7 +151,7 @@ middleware appState app req respond = do let token = fromMaybe "" $ Wai.extractBearerAuth =<< lookup HTTP.hAuthorization (Wai.requestHeaders req) parseJwt = runExceptT $ parseToken conf token time >>= parseClaims conf --- If DbPlanEnabled -> calculate JWT validation time +-- If ServerTimingEnabled -> calculate JWT validation time -- If JwtCacheMaxLifetime -> cache JWT validation result req' <- case (configServerTimingEnabled conf, configJwtCacheMaxLifetime conf) of (True, 0) -> do @@ -174,51 +172,6 @@ middleware appState app req respond = do app req' respond --- | Used to retrieve and insert JWT to JWT Cache -getJWTFromCache :: AppState -> ByteString -> Int -> IO (Either Error AuthResult) -> UTCTime -> IO (Either Error AuthResult) -getJWTFromCache appState token maxLifetime parseJwt utc = do - checkCache <- C.lookup (getJwtCache appState) token - authResult <- maybe parseJwt (pure . Right) checkCache - - case (authResult,checkCache) of - -- From comment: - -- https://github.com/PostgREST/postgrest/pull/3801#discussion_r1857987914 - -- - -- We purge expired cache entries on a cache miss - -- The reasoning is that: - -- - -- 1. We expect it to be rare (otherwise there is no point of the cache) - -- 2. It makes sure the cache is not growing (as inserting new entries - -- does garbage collection) - -- 3. Since this is time expiration based cache there is no real risk of - -- starvation - sooner or later we are going to have a cache miss. 
- - (Right res, Nothing) -> do -- cache miss - - let timeSpec = getTimeSpec res maxLifetime utc - - -- purge expired cache entries - C.purgeExpired jwtCache - - -- insert new cache entry - C.insert' jwtCache timeSpec token res - - _ -> pure () - - return authResult - where - jwtCache = getJwtCache appState - --- Used to extract JWT exp claim and add to JWT Cache -getTimeSpec :: AuthResult -> Int -> UTCTime -> Maybe TimeSpec -getTimeSpec res maxLifetime utc = do - let expireJSON = KM.lookup "exp" (authClaims res) - utcToSecs = floor . nominalDiffTimeToSeconds . utcTimeToPOSIXSeconds - sciToInt = fromMaybe 0 . Sci.toBoundedInteger - case expireJSON of - Just (JSON.Number seconds) -> Just $ TimeSpec (sciToInt seconds - utcToSecs utc) 0 - _ -> Just $ TimeSpec (fromIntegral maxLifetime :: Int64) 0 - authResultKey :: Vault.Key (Either Error AuthResult) authResultKey = unsafePerformIO Vault.newKey {-# NOINLINE authResultKey #-} diff --git a/src/PostgREST/Auth/Cache.hs b/src/PostgREST/Auth/Cache.hs new file mode 100644 index 0000000000..1405c1885d --- /dev/null +++ b/src/PostgREST/Auth/Cache.hs @@ -0,0 +1,111 @@ +{-| +Module : PostgREST.Auth.Cache +Description : Cache to store parsed Jwt Authentication Result +-} +module PostgREST.Auth.Cache + ( getJWTFromCache ) + where + +import qualified Data.Aeson as JSON +import qualified Data.Aeson.KeyMap as KM +import qualified Data.Cache as C +import qualified Data.Scientific as Sci + +import Data.Maybe (fromJust) +import Data.Time.Clock (UTCTime, nominalDiffTimeToSeconds) +import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds) +import System.Clock (TimeSpec (..)) +import System.IO.Unsafe (unsafePerformIO) + +import PostgREST.AppState (AppState, getJwtCache, getObserver) +import PostgREST.Auth.Types (AuthResult (..)) +import PostgREST.Error (Error (..)) +import PostgREST.Internal (recursiveSizeNF) +import PostgREST.Observation (Observation (..)) + +import Protolude + +-- | Used to retrieve and insert JWT to JWT Cache +getJWTFromCache 
:: AppState -> ByteString -> Int -> IO (Either Error AuthResult) -> UTCTime -> IO (Either Error AuthResult) +getJWTFromCache appState token maxLifetime parseJwt utc = do + + checkCache <- C.lookup jwtCache token + authResult <- maybe parseJwt (pure . Right) checkCache + + -- if the token is not found, add it to the cache and recompute the cache size metric + case (authResult,checkCache) of + -- From comment: + -- https://github.com/PostgREST/postgrest/pull/3801#discussion_r1857987914 + -- + -- We purge expired cache entries on a cache miss + -- The reasoning is that: + -- + -- 1. We expect it to be rare (otherwise there is no point of the cache) + -- 2. It makes sure the cache is not growing (as inserting new entries + -- does garbage collection) + -- 3. Since this is time expiration based cache there is no real risk of + -- starvation - sooner or later we are going to have a cache miss. + + (Right res, Nothing) -> do -- cache miss + + let timeSpec = getTimeSpec res maxLifetime utc + + -- purge expired cache entries (VERY INEFFICIENT) + C.purgeExpired jwtCache + + -- insert new cache entry + C.insert' jwtCache (Just timeSpec) token res + + -- calculate cache size (VERY INEFFICIENT) + cacheSize <- calcCacheSizeInBytes jwtCache + + -- emit the new cache size as an observation (it is logged and exported as a metric) + observer $ JWTCache cacheSize + + _ -> pure () + + return authResult + where + observer = getObserver appState + jwtCache = getJwtCache appState + +-- Used to extract JWT exp claim and add to JWT Cache +getTimeSpec :: AuthResult -> Int -> UTCTime -> TimeSpec +getTimeSpec res maxLifetime utc = do + let expireJSON = KM.lookup "exp" (authClaims res) + utcToSecs = floor . nominalDiffTimeToSeconds . utcTimeToPOSIXSeconds + sciToInt = fromMaybe 0 . 
Sci.toBoundedInteger + case expireJSON of + Just (JSON.Number seconds) -> TimeSpec (sciToInt seconds - utcToSecs utc) 0 + _ -> TimeSpec (fromIntegral maxLifetime :: Int64) 0 + +-- | Calculate JWT Cache Size in Bytes +-- +-- The cache size is calculated by summing the size of every +-- cache entry; the caller reports the result as the metric. +-- +-- The cache entry consists of +-- key :: ByteString +-- value :: AuthResult +-- expire value :: TimeSpec +-- +-- We calculate the size of each cache entry component +-- by using the recursiveSizeNF function, which first evaluates +-- the data structure to Normal Form and then calculates its size. +-- The normal form evaluation is necessary for accurate size +-- calculation because Haskell is lazy and we don't want to count +-- the size of large thunks (unevaluated expressions). NOTE(review): sizes are forced via unsafePerformIO, and fromJust assumes every entry was inserted with an expiry - confirm this stays true. +calcCacheSizeInBytes :: C.Cache ByteString AuthResult -> IO Int +calcCacheSizeInBytes jwtCache = do + cacheList <- C.toList jwtCache + let szList = [ unsafePerformIO (getSize (bs, ar, fromJust ts)) | (bs, ar, ts) <- cacheList] + return $ fromIntegral (sum szList) + where + getSize :: (ByteString, AuthResult, TimeSpec) -> IO Word + getSize (bs, ar, ts) = do + keySize <- recursiveSizeNF bs + arClaimsSize <- recursiveSizeNF $ authClaims ar + arRoleSize <- recursiveSizeNF $ authRole ar + timeSpecSize <- liftA2 (+) (recursiveSizeNF (sec ts)) (recursiveSizeNF (nsec ts)) + + return (keySize + arClaimsSize + arRoleSize + timeSpecSize) diff --git a/src/PostgREST/Internal.hs b/src/PostgREST/Internal.hs new file mode 100644 index 0000000000..c73f25edaf --- /dev/null +++ b/src/PostgREST/Internal.hs @@ -0,0 +1,93 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE MagicHash #-} +{-# LANGUAGE UnboxedTuples #-} +{- | +Module : PostgREST.Internal +Copyright : (c) Dennis Felsing +License : 3-Clause BSD-style +Maintainer : dennis@felsing.org + +https://hackage.haskell.org/package/ghc-datasize + +This vendored dependency can be removed once https://github.com/PostgREST/postgrest/issues/3881 is solved. 
+-} +module PostgREST.Internal + ( recursiveSizeNF ) + where + +import GHC.Exts +import GHC.Exts.Heap hiding (size) +import GHC.Exts.Heap.Constants (wORD_SIZE) + +import System.Mem + +import Protolude + +-- Inspired by Simon Marlow: +-- https://ghcmutterings.wordpress.com/2009/02/12/53/ + +-- | Calculate size of GHC objects in Bytes. Note that an object may not be +-- evaluated yet and only the size of the initial closure is returned. +closureSize :: a -> IO Word +closureSize x = do + rawWds <- getClosureRawWords x + return . fromIntegral $ length rawWds * wORD_SIZE + +-- | Calculate the recursive size of GHC objects in Bytes. Note that the actual +-- size in memory is calculated, so shared values are only counted once. +-- +-- Call with +-- @ +-- recursiveSize $! 2 +-- @ +-- to force evaluation to WHNF before calculating the size. +-- +-- Call with +-- @ +-- recursiveSize $!! \"foobar\" +-- @ +-- ($!! from Control.DeepSeq) to force full evaluation before calculating the +-- size. +-- +-- A garbage collection is performed before the size is calculated, because +-- the garbage collector would make heap walks difficult. +-- +-- This function works very quickly on small data structures, but can be slow +-- on large and complex ones. If speed is an issue it's probably possible to +-- get the exact size of a small portion of the data structure and then +-- estimate the total size from that. + +recursiveSize :: a -> IO Word +recursiveSize x = do + performGC + fmap snd $ go ([], 0) $ asBox x + where + go (!vs, !acc) b@(Box y) = do + isElem <- or <$> mapM (areBoxesEqual b) vs + if isElem + then return (vs, acc) + else do + size <- closureSize y + closure <- getClosureData y + foldM go (b : vs, acc + size) $ allClosures closure + +-- | Calculate the recursive size of GHC objects in Bytes after calling +-- Control.DeepSeq.force on the data structure to force it into Normal Form. +-- Using this function requires that the data structure has an `NFData` +-- typeclass instance. 
+ +recursiveSizeNF :: NFData a => a -> IO Word +recursiveSizeNF x = recursiveSize $!! x + +-- | Adapted from 'GHC.Exts.Heap.getClosureRaw' which isn't exported. +-- +-- This returns the raw words of the closure on the heap. Once back in the +-- Haskell world, the raw words that hold pointers may be outdated after a +-- garbage collector run. +getClosureRawWords :: a -> IO [Word] +getClosureRawWords x = do + case unpackClosure# x of + (# _iptr, dat, _pointers #) -> do + let nelems = I# (sizeofByteArray# dat) `div` wORD_SIZE + end = nelems - 1 + pure [W# (indexWordArray# dat i) | I# i <- [0.. end] ] diff --git a/src/PostgREST/Logger.hs b/src/PostgREST/Logger.hs index c224f74c79..6ebe06e54c 100644 --- a/src/PostgREST/Logger.hs +++ b/src/PostgREST/Logger.hs @@ -88,6 +88,9 @@ observationLogger loggerState logLevel obs = case obs of o@(HasqlPoolObs _) -> do when (logLevel >= LogDebug) $ do logWithZTime loggerState $ observationMessage o + o@(JWTCache _) -> do + when (logLevel >= LogInfo) $ do + logWithZTime loggerState $ observationMessage o PoolRequest -> pure () PoolRequestFullfilled -> diff --git a/src/PostgREST/Metrics.hs b/src/PostgREST/Metrics.hs index 3999e43d83..0a94ad899b 100644 --- a/src/PostgREST/Metrics.hs +++ b/src/PostgREST/Metrics.hs @@ -1,5 +1,5 @@ {-| -Module : PostgREST.Logger +Module : PostgREST.Metrics Description : Metrics based on the Observation module. See Observation.hs. 
-} module PostgREST.Metrics @@ -19,7 +19,7 @@ import PostgREST.Observation import Protolude data MetricsState = - MetricsState Counter Gauge Gauge Gauge (Vector Label1 Counter) Gauge + MetricsState Counter Gauge Gauge Gauge (Vector Label1 Counter) Gauge Gauge init :: Int -> IO MetricsState init configDbPoolSize = do @@ -29,12 +29,13 @@ init configDbPoolSize = do poolMaxSize <- register $ gauge (Info "pgrst_db_pool_max" "Max pool connections") schemaCacheLoads <- register $ vector "status" $ counter (Info "pgrst_schema_cache_loads_total" "The total number of times the schema cache was loaded") schemaCacheQueryTime <- register $ gauge (Info "pgrst_schema_cache_query_time_seconds" "The query time in seconds of the last schema cache load") + jwtCacheSize <- register $ gauge (Info "pgrst_jwt_cache_size_bytes" "The JWT cache size in bytes") setGauge poolMaxSize (fromIntegral configDbPoolSize) - pure $ MetricsState poolTimeouts poolAvailable poolWaiting poolMaxSize schemaCacheLoads schemaCacheQueryTime + pure $ MetricsState poolTimeouts poolAvailable poolWaiting poolMaxSize schemaCacheLoads schemaCacheQueryTime jwtCacheSize -- Only some observations are used as metrics observationMetrics :: MetricsState -> ObservationHandler -observationMetrics (MetricsState poolTimeouts poolAvailable poolWaiting _ schemaCacheLoads schemaCacheQueryTime) obs = case obs of +observationMetrics (MetricsState poolTimeouts poolAvailable poolWaiting _ schemaCacheLoads schemaCacheQueryTime jwtCacheSize) obs = case obs of (PoolAcqTimeoutObs _) -> do incCounter poolTimeouts (HasqlPoolObs (SQL.ConnectionObservation _ status)) -> case status of @@ -54,6 +55,8 @@ observationMetrics (MetricsState poolTimeouts poolAvailable poolWaiting _ schema setGauge schemaCacheQueryTime resTime SchemaCacheErrorObs _ -> do withLabel schemaCacheLoads "FAIL" incCounter + JWTCache cacheSize -> do + setGauge jwtCacheSize (fromIntegral cacheSize) _ -> pure () diff --git a/src/PostgREST/Observation.hs 
b/src/PostgREST/Observation.hs index 18fbf558d7..5ac9481f54 100644 --- a/src/PostgREST/Observation.hs +++ b/src/PostgREST/Observation.hs @@ -57,6 +57,7 @@ data Observation | HasqlPoolObs SQL.Observation | PoolRequest | PoolRequestFullfilled + | JWTCache Int data ObsFatalError = ServerAuthError | ServerPgrstBug | ServerError42P05 | ServerError08P01 @@ -138,6 +139,7 @@ observationMessage = \case SQL.ReleaseConnectionTerminationReason -> "release" SQL.NetworkErrorConnectionTerminationReason _ -> "network error" -- usage error is already logged, no need to repeat the same message. ) + JWTCache sz -> "The JWT Cache size updated to " <> show sz <> " bytes" _ -> mempty where showMillis :: Double -> Text diff --git a/test/io/test_io.py b/test/io/test_io.py index ff5619fb97..effb5c1960 100644 --- a/test/io/test_io.py +++ b/test/io/test_io.py @@ -1632,6 +1632,8 @@ def test_admin_metrics(defaultenv): assert "pgrst_db_pool_available" in response.text assert "pgrst_db_pool_timeouts_total" in response.text + assert "pgrst_jwt_cache_size_bytes" in response.text + def test_schema_cache_startup_load_with_in_db_config(defaultenv, metapostgrest): "verify that the Schema Cache loads correctly at startup, using the in-db `pgrst.db_schemas` config" @@ -1715,3 +1717,35 @@ def test_pgrst_log_503_client_error_to_stderr(defaultenv): log_message = '{"code":"PGRST001","details":"no connection to the server\\n","hint":null,"message":"Database client error. 
Retrying the connection."}\n' assert any(log_message in line for line in output) + + +def test_jwt_cache_size_bytes_update_log(defaultenv): + "JWT cache size should update on a cache miss" + + env = { + **defaultenv, + "PGRST_LOG_LEVEL": "debug", + "PGRST_JWT_CACHE_MAX_LIFETIME": "86400", + "PGRST_JWT_SECRET": SECRET, + } + + headers = jwtauthheader({"role": "postgrest_test_author"}, SECRET) + + with run(env=env) as postgrest: + response = postgrest.session.get("/authors_only", headers=headers) + assert response.status_code == 200 + + output = sorted(postgrest.read_stdout(nlines=3)) + + response = postgrest.admin.get("/metrics") + assert response.status_code == 200 + + # read cache size from metrics + cache_size = float( + re.search(r"pgrst_jwt_cache_size_bytes (\d+)", response.text).group(1) + ) + + assert ( + "The JWT Cache size updated to " + str(int(cache_size)) + " bytes" + in output[2] + )