Skip to content

Commit

Permalink
feat: add metric pgrst_jwt_cache_size_bytes in admin server
Browse files Browse the repository at this point in the history
  • Loading branch information
taimoorzaeem committed Feb 10, 2025
1 parent c96dc3e commit 65f3bc1
Show file tree
Hide file tree
Showing 11 changed files with 272 additions and 56 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ This project adheres to [Semantic Versioning](http://semver.org/).
- #1536, Add string comparison feature for jwt-role-claim-key - @taimoorzaeem
- #3747, Allow `not_null` value for the `is` operator - @taimoorzaeem
- #2255, Apply `to_tsvector()` explicitly to the full-text search filtered column (excluding `tsvector` types) - @laurenceisla
- #3802, Add metric `pgrst_jwt_cache_size_bytes` in admin server - @taimoorzaeem

### Fixed

Expand Down
14 changes: 14 additions & 0 deletions docs/references/observability.rst
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,20 @@ pgrst_db_pool_max

Max pool connections.

JWT Cache Metric
----------------

Related to the :ref:`jwt_caching`.

pgrst_jwt_cache_size_bytes
~~~~~~~~~~~~~~~~~~~~~~~~~~

======== =======
**Type** Gauge
======== =======

The JWT cache size in bytes.

Traces
======

Expand Down
5 changes: 4 additions & 1 deletion postgrest.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ library
PostgREST.App
PostgREST.AppState
PostgREST.Auth
PostgREST.Auth.Cache
PostgREST.Auth.Types
PostgREST.CLI
PostgREST.Config
Expand Down Expand Up @@ -90,7 +91,8 @@ library
PostgREST.Response.GucHeader
PostgREST.Response.Performance
PostgREST.Version
other-modules: Paths_postgrest
other-modules: PostgREST.Internal
Paths_postgrest
build-depends: base >= 4.9 && < 4.20
, HTTP >= 4000.3.7 && < 4000.5
, Ranged-sets >= 0.3 && < 0.5
Expand All @@ -109,6 +111,7 @@ library
, either >= 4.4.1 && < 5.1
, extra >= 1.7.0 && < 2.0
, fuzzyset >= 0.2.4 && < 0.3
, ghc-heap >= 9.4 && < 9.9
, hasql >= 1.6.1.1 && < 1.7
, hasql-dynamic-statements >= 0.3.1 && < 0.4
, hasql-notifications >= 0.2.2.2 && < 0.2.3
Expand Down
1 change: 0 additions & 1 deletion src/PostgREST/App.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ module PostgREST.App
, run
) where


import Control.Monad.Except (liftEither)
import Data.Either.Combinators (mapLeft)
import Data.Maybe (fromJust)
Expand Down
53 changes: 3 additions & 50 deletions src/PostgREST/Auth.hs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import qualified Data.Aeson.KeyMap as KM
import qualified Data.Aeson.Types as JSON
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy.Char8 as LBS
import qualified Data.Cache as C
import qualified Data.Scientific as Sci
import qualified Data.Text as T
import qualified Data.Vault.Lazy as Vault
Expand All @@ -40,12 +39,11 @@ import Data.Either.Combinators (mapLeft)
import Data.List (lookup)
import Data.Time.Clock (UTCTime, nominalDiffTimeToSeconds)
import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds)
import System.Clock (TimeSpec (..))
import System.IO.Unsafe (unsafePerformIO)
import System.TimeIt (timeItT)

import PostgREST.AppState (AppState, getConfig, getJwtCache,
getTime)
import PostgREST.AppState (AppState, getConfig, getTime)
import PostgREST.Auth.Cache (getJWTFromCache)
import PostgREST.Auth.Types (AuthResult (..))
import PostgREST.Config (AppConfig (..), FilterExp (..), JSPath,
JSPathExp (..))
Expand Down Expand Up @@ -153,7 +151,7 @@ middleware appState app req respond = do
let token = fromMaybe "" $ Wai.extractBearerAuth =<< lookup HTTP.hAuthorization (Wai.requestHeaders req)
parseJwt = runExceptT $ parseToken conf token time >>= parseClaims conf

-- If DbPlanEnabled -> calculate JWT validation time
-- If ServerTimingEnabled -> calculate JWT validation time
-- If JwtCacheMaxLifetime -> cache JWT validation result
req' <- case (configServerTimingEnabled conf, configJwtCacheMaxLifetime conf) of
(True, 0) -> do
Expand All @@ -174,51 +172,6 @@ middleware appState app req respond = do

app req' respond

-- | Used to retrieve and insert JWT to JWT Cache
getJWTFromCache :: AppState -> ByteString -> Int -> IO (Either Error AuthResult) -> UTCTime -> IO (Either Error AuthResult)
getJWTFromCache appState token maxLifetime parseJwt utc = do
checkCache <- C.lookup (getJwtCache appState) token
authResult <- maybe parseJwt (pure . Right) checkCache

case (authResult,checkCache) of
-- From comment:
-- https://github.com/PostgREST/postgrest/pull/3801#discussion_r1857987914
--
-- We purge expired cache entries on a cache miss
-- The reasoning is that:
--
-- 1. We expect it to be rare (otherwise there is no point of the cache)
-- 2. It makes sure the cache is not growing (as inserting new entries
-- does garbage collection)
-- 3. Since this is time expiration based cache there is no real risk of
-- starvation - sooner or later we are going to have a cache miss.

(Right res, Nothing) -> do -- cache miss

let timeSpec = getTimeSpec res maxLifetime utc

-- purge expired cache entries
C.purgeExpired jwtCache

-- insert new cache entry
C.insert' jwtCache timeSpec token res

_ -> pure ()

return authResult
where
jwtCache = getJwtCache appState

-- Used to extract JWT exp claim and add to JWT Cache
getTimeSpec :: AuthResult -> Int -> UTCTime -> Maybe TimeSpec
getTimeSpec res maxLifetime utc = do
let expireJSON = KM.lookup "exp" (authClaims res)
utcToSecs = floor . nominalDiffTimeToSeconds . utcTimeToPOSIXSeconds
sciToInt = fromMaybe 0 . Sci.toBoundedInteger
case expireJSON of
Just (JSON.Number seconds) -> Just $ TimeSpec (sciToInt seconds - utcToSecs utc) 0
_ -> Just $ TimeSpec (fromIntegral maxLifetime :: Int64) 0

authResultKey :: Vault.Key (Either Error AuthResult)
authResultKey = unsafePerformIO Vault.newKey
{-# NOINLINE authResultKey #-}
Expand Down
111 changes: 111 additions & 0 deletions src/PostgREST/Auth/Cache.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
{-|
Module : PostgREST.Auth.Cache
Description : Cache to store parsed Jwt Authentication Result
-}
module PostgREST.Auth.Cache
( getJWTFromCache )
where

import qualified Data.Aeson as JSON
import qualified Data.Aeson.KeyMap as KM
import qualified Data.Cache as C
import qualified Data.Scientific as Sci

import Data.Maybe (fromJust)
import Data.Time.Clock (UTCTime, nominalDiffTimeToSeconds)
import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds)
import System.Clock (TimeSpec (..))
import System.IO.Unsafe (unsafePerformIO)

import PostgREST.AppState (AppState, getJwtCache, getObserver)
import PostgREST.Auth.Types (AuthResult (..))
import PostgREST.Error (Error (..))
import PostgREST.Internal (recursiveSizeNF)
import PostgREST.Observation (Observation (..))

import Protolude

-- | Used to retrieve and insert JWT to JWT Cache
getJWTFromCache :: AppState -> ByteString -> Int -> IO (Either Error AuthResult) -> UTCTime -> IO (Either Error AuthResult)
getJWTFromCache appState token maxLifetime parseJwt utc = do

checkCache <- C.lookup jwtCache token
authResult <- maybe parseJwt (pure . Right) checkCache

-- if token not found, add to cache and increment cache size metric
case (authResult,checkCache) of
-- From comment:
-- https://github.com/PostgREST/postgrest/pull/3801#discussion_r1857987914
--
-- We purge expired cache entries on a cache miss
-- The reasoning is that:
--
-- 1. We expect it to be rare (otherwise there is no point of the cache)
-- 2. It makes sure the cache is not growing (as inserting new entries
-- does garbage collection)
-- 3. Since this is time expiration based cache there is no real risk of
-- starvation - sooner or later we are going to have a cache miss.

(Right res, Nothing) -> do -- cache miss

let timeSpec = getTimeSpec res maxLifetime utc

-- purge expired cache entries (VERY INEFFICIENT)
C.purgeExpired jwtCache

-- insert new cache entry
C.insert' jwtCache (Just timeSpec) token res

-- calculate cache size (VERY INEFFICIENT)
cacheSize <- calcCacheSizeInBytes jwtCache

-- log cache size
observer $ JWTCache cacheSize

_ -> pure ()

return authResult
where
observer = getObserver appState
jwtCache = getJwtCache appState

-- Used to extract JWT exp claim and add to JWT Cache
getTimeSpec :: AuthResult -> Int -> UTCTime -> TimeSpec
getTimeSpec res maxLifetime utc = do
let expireJSON = KM.lookup "exp" (authClaims res)
utcToSecs = floor . nominalDiffTimeToSeconds . utcTimeToPOSIXSeconds
sciToInt = fromMaybe 0 . Sci.toBoundedInteger
case expireJSON of
Just (JSON.Number seconds) -> TimeSpec (sciToInt seconds - utcToSecs utc) 0
_ -> TimeSpec (fromIntegral maxLifetime :: Int64) 0

-- | Calculate JWT Cache Size in Bytes
--
-- The cache size is updated by calculating the size of every
-- cache entry and updating the metric.
--
-- The cache entry consists of
-- key :: ByteString
-- value :: AuthReult
-- expire value :: TimeSpec
--
-- We calculate the size of each cache entry component
-- by using recursiveSizeNF function which first evaluates
-- the data structure to Normal Form and then calculate size.
-- The normal form evaluation is necessary for accurate size
-- calculation because haskell is lazy and we dont wanna count
-- the size of large thunks (unevaluated expressions)
calcCacheSizeInBytes :: C.Cache ByteString AuthResult -> IO Int
calcCacheSizeInBytes jwtCache = do
cacheList <- C.toList jwtCache
let szList = [ unsafePerformIO (getSize (bs, ar, fromJust ts)) | (bs, ar, ts) <- cacheList]
return $ fromIntegral (sum szList)
where
getSize :: (ByteString, AuthResult, TimeSpec) -> IO Word
getSize (bs, ar, ts) = do
keySize <- recursiveSizeNF bs
arClaimsSize <- recursiveSizeNF $ authClaims ar
arRoleSize <- recursiveSizeNF $ authRole ar
timeSpecSize <- liftA2 (+) (recursiveSizeNF (sec ts)) (recursiveSizeNF (nsec ts))

return (keySize + arClaimsSize + arRoleSize + timeSpecSize)
93 changes: 93 additions & 0 deletions src/PostgREST/Internal.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE UnboxedTuples #-}
{- |
Module : PostgREST.Internal
Copyright : (c) Dennis Felsing
License : 3-Clause BSD-style
Maintainer : [email protected]
https://hackage.haskell.org/package/ghc-datasize
This vendored dependency can be removed once https://github.com/PostgREST/postgrest/issues/3881 is solved.
-}
module PostgREST.Internal
( recursiveSizeNF )
where

import GHC.Exts
import GHC.Exts.Heap hiding (size)
import GHC.Exts.Heap.Constants (wORD_SIZE)

import System.Mem

import Protolude

-- Inspired by Simon Marlow:
-- https://ghcmutterings.wordpress.com/2009/02/12/53/

-- | Calculate size of GHC objects in Bytes. Note that an object may not be
-- evaluated yet and only the size of the initial closure is returned.
closureSize :: a -> IO Word
closureSize x = do
rawWds <- getClosureRawWords x
return . fromIntegral $ length rawWds * wORD_SIZE

-- | Calculate the recursive size of GHC objects in Bytes. Note that the actual
-- size in memory is calculated, so shared values are only counted once.
--
-- Call with
-- @
-- recursiveSize $! 2
-- @
-- to force evaluation to WHNF before calculating the size.
--
-- Call with
-- @
-- recursiveSize $!! \"foobar\"
-- @
-- ($!! from Control.DeepSeq) to force full evaluation before calculating the
-- size.
--
-- A garbage collection is performed before the size is calculated, because
-- the garbage collector would make heap walks difficult.
--
-- This function works very quickly on small data structures, but can be slow
-- on large and complex ones. If speed is an issue it's probably possible to
-- get the exact size of a small portion of the data structure and then
-- estimate the total size from that.

recursiveSize :: a -> IO Word
recursiveSize x = do
performGC
fmap snd $ go ([], 0) $ asBox x
where
go (!vs, !acc) b@(Box y) = do
isElem <- or <$> mapM (areBoxesEqual b) vs
if isElem
then return (vs, acc)
else do
size <- closureSize y
closure <- getClosureData y
foldM go (b : vs, acc + size) $ allClosures closure

-- | Calculate the recursive size of GHC objects in Bytes after calling
-- Control.DeepSeq.force on the data structure to force it into Normal Form.
-- Using this function requires that the data structure has an `NFData`
-- typeclass instance.

recursiveSizeNF :: NFData a => a -> IO Word
recursiveSizeNF x = recursiveSize $!! x

-- | Adapted from 'GHC.Exts.Heap.getClosureRaw' which isn't exported.
--
-- This returns the raw words of the closure on the heap. Once back in the
-- Haskell world, the raw words that hold pointers may be outdated after a
-- garbage collector run.
getClosureRawWords :: a -> IO [Word]
getClosureRawWords x = do
case unpackClosure# x of
(# _iptr, dat, _pointers #) -> do
let nelems = I# (sizeofByteArray# dat) `div` wORD_SIZE
end = nelems - 1
pure [W# (indexWordArray# dat i) | I# i <- [0.. end] ]
3 changes: 3 additions & 0 deletions src/PostgREST/Logger.hs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ observationLogger loggerState logLevel obs = case obs of
o@(HasqlPoolObs _) -> do
when (logLevel >= LogDebug) $ do
logWithZTime loggerState $ observationMessage o
o@(JWTCache _) -> do
when (logLevel >= LogInfo) $ do
logWithZTime loggerState $ observationMessage o
PoolRequest ->
pure ()
PoolRequestFullfilled ->
Expand Down
Loading

0 comments on commit 65f3bc1

Please sign in to comment.