diff --git a/nix/tools/tests.nix b/nix/tools/tests.nix index 2bbac03d39..18ba0940c7 100644 --- a/nix/tools/tests.nix +++ b/nix/tools/tests.nix @@ -78,6 +78,7 @@ let ioTestPython = python3.withPackages (ps: [ ps.pyjwt + ps.psycopg ps.pytest ps.pytest-xdist ps.pyyaml diff --git a/postgrest.cabal b/postgrest.cabal index 63840ec096..2890fe598f 100644 --- a/postgrest.cabal +++ b/postgrest.cabal @@ -65,6 +65,7 @@ library PostgREST.SchemaCache.Relationship PostgREST.SchemaCache.Representations PostgREST.SchemaCache.Table + PostgREST.SqlTransaction PostgREST.Error PostgREST.Error.Types PostgREST.Listener diff --git a/src/PostgREST/Config.hs b/src/PostgREST/Config.hs index bf48d156bd..fe7d20e964 100644 --- a/src/PostgREST/Config.hs +++ b/src/PostgREST/Config.hs @@ -127,6 +127,7 @@ data AppConfig = AppConfig , configRoleSettings :: RoleSettings , configRoleIsoLvl :: RoleIsolationLvl , configInternalSCQuerySleep :: Maybe Int32 + , configInternalSCLockId :: Maybe Int32 } data LogLevel = LogCrit | LogError | LogWarn | LogInfo | LogDebug @@ -326,6 +327,7 @@ parser optPath env dbSettings roleSettings roleIsolationLvl = <*> pure roleSettings <*> pure roleIsolationLvl <*> optInt "internal-schema-cache-query-sleep" + <*> optInt "internal-schema-cache-lock-id" where parseErrorVerbosity :: C.Key -> C.Parser C.Config Verbosity parseErrorVerbosity k = diff --git a/src/PostgREST/Logger.hs b/src/PostgREST/Logger.hs index b40585ca80..144ec00712 100644 --- a/src/PostgREST/Logger.hs +++ b/src/PostgREST/Logger.hs @@ -32,7 +32,6 @@ import PostgREST.Debounce (makeDebouncer) import PostgREST.Logger.Apache (apacheFormat) import PostgREST.Observation import PostgREST.Query (MainQuery (..)) -import PostgREST.SchemaCache (queryTimingsWLabels) import qualified Data.ByteString.Lazy as LBS import qualified Data.Text as T @@ -160,7 +159,7 @@ observationMessages = \case <> ". " <> jsonMessage usageErr SchemaCacheQueriedObs resultTime timings -> [ "Schema cache queried in " <> showMillis resultTime <> " milliseconds " ] <> - let showTimings qt = [ T.intercalate ", " $ (\(l, v) -> T.decodeUtf8 l <> ": " <> v <> " ms") <$> queryTimingsWLabels qt ] in + let showTimings qt = [ T.intercalate ", " $ (\(l, v) -> T.decodeUtf8 l <> ": " <> v <> " ms") <$> qt ] in maybe mempty showTimings timings SchemaCacheLoadedObs resultTime summary -> [ diff --git a/src/PostgREST/Observation.hs b/src/PostgREST/Observation.hs index cabf8fdd26..1a05880e4e 100644 --- a/src/PostgREST/Observation.hs +++ b/src/PostgREST/Observation.hs @@ -19,7 +19,7 @@ import Network.HTTP.Types.Status (Status) import qualified Network.Wai as Wai import PostgREST.Config.PgVersion import PostgREST.Query (MainQuery) -import PostgREST.SchemaCache (QueryTimings) +import PostgREST.SqlTransaction (QueryTimings) import Protolude hiding (toList) diff --git a/src/PostgREST/SchemaCache.hs b/src/PostgREST/SchemaCache.hs index a487b70089..9165b322d0 100644 --- a/src/PostgREST/SchemaCache.hs +++ b/src/PostgREST/SchemaCache.hs @@ -8,15 +8,20 @@ The schema cache is necessary for resource embedding, foreign keys are used for These queries are executed once at startup or when PostgREST is reloaded. -} +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE DataKinds #-} {-# LANGUAGE DeriveAnyClass #-} {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE PolyKinds #-} {-# LANGUAGE QuasiQuotes #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TypeSynonymInstances #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} module PostgREST.SchemaCache ( SchemaCache(..) @@ -24,14 +29,11 @@ module PostgREST.SchemaCache , querySchemaCache , showSummary , decodeFuncs - , QueryTimings(..) - , queryTimingsWLabels ) where import Data.Aeson ((.=)) import qualified Data.Aeson as JSON -import qualified Data.ByteString.Char8 as BS import qualified Data.HashMap.Strict as HM import qualified Data.HashMap.Strict.InsOrd as HMI import qualified Data.Set as S @@ -39,36 +41,41 @@ import qualified Data.Text as T import qualified Hasql.Decoders as HD import qualified Hasql.Encoders as HE import qualified Hasql.Statement as SQL -import qualified Hasql.Transaction as SQL +import qualified Hasql.Transaction as SQL hiding (sql, + statement) import Data.Functor.Contravariant ((>$<)) import NeatInterpolation (trimming) -import PostgREST.Config (AppConfig (..), - LogLevel (..)) -import PostgREST.Config.Database (TimezoneNames, - toIsolationLevel) -import PostgREST.SchemaCache.Identifiers (FieldName, - QualifiedIdentifier (..), - RelIdentifier (..), - Schema, escapeIdent, - isAnyElement) -import PostgREST.SchemaCache.Relationship (Cardinality (..), - Junction (..), - Relationship (..), - RelationshipsMap) -import PostgREST.SchemaCache.Representations (DataRepresentation (..), - RepresentationsMap) -import PostgREST.SchemaCache.Routine (FuncVolatility (..), - MediaHandler (..), - MediaHandlerMap, - PgType (..), - RetType (..), - Routine (..), - RoutineMap, - RoutineParam (..)) -import PostgREST.SchemaCache.Table (Column (..), ColumnMap, - Table (..), TablesMap) +import PostgREST.Config (AppConfig (..), + LogLevel (..)) +import PostgREST.Config.Database (TimezoneNames, + toIsolationLevel) +import PostgREST.SchemaCache.Identifiers (FieldName, + QualifiedIdentifier (..), + RelIdentifier (..), + Schema, + escapeIdent, + isAnyElement) +import PostgREST.SchemaCache.Relationship (Cardinality (..), + Junction (..), + Relationship (..), + RelationshipsMap) +import PostgREST.SchemaCache.Representations (DataRepresentation (..), + RepresentationsMap) +import PostgREST.SchemaCache.Routine (FuncVolatility (..), + MediaHandler (..), + MediaHandlerMap, + PgType (..), + RetType (..), + Routine (..), + RoutineMap, + RoutineParam (..)) +import PostgREST.SchemaCache.Table (Column (..), + ColumnMap, + Table (..), + TablesMap) +import qualified PostgREST.SqlTransaction as SQL import qualified PostgREST.MediaType as MediaType @@ -153,48 +160,47 @@ type SqlQuery = ByteString maxDbTablesForFuzzySearch :: Int maxDbTablesForFuzzySearch = 500 -querySchemaCache :: AppConfig -> SQL.Transaction (SchemaCache, Maybe QueryTimings) -querySchemaCache conf@AppConfig{..} = do - SQL.sql "set local schema ''" -- This voids the search path. The following queries need this for getting the fully qualified name(schema.name) of every db object - tabs <- sqlTimedStmt gucTbls conf allTables - keyDeps <- sqlTimedStmt gucKDeps conf allViewsKeyDependencies - m2oRels <- sqlTimedStmt gucRels mempty allM2OandO2ORels - funcs <- sqlTimedStmt gucFuncs conf allFunctions - cRels <- sqlTimedStmt gucCRels mempty allComputedRels - reps <- sqlTimedStmt gucDReps conf dataRepresentations - mHdlers <- sqlTimedStmt gucMHdrs conf mediaHandlers - tzones <- if configDbTimezoneEnabled - then sqlTimedStmt gucTzones mempty timezones - else pure S.empty - _ <- - let sleepCall = SQL.Statement "select pg_sleep($1 / 1000.0)" (param HE.int4) HD.noResult True in - for_ configInternalSCQuerySleep (`SQL.statement` sleepCall) -- only used for testing - - qsTime <- - if isLogDebug - then Just <$> SQL.statement mempty (extractTimings configDbTimezoneEnabled) - else pure Nothing - - let tabsWViewsPks = addViewPrimaryKeys tabs keyDeps - rels = addInverseRels $ addM2MRels tabsWViewsPks $ addViewM2OAndO2ORels keyDeps m2oRels - - return (removeInternal schemas $ SchemaCache { - dbTables = tabsWViewsPks - , dbRelationships = getOverrideRelationshipsMap rels cRels - , dbRoutines = funcs - , dbRepresentations = reps - , dbMediaHandlers = HM.union mHdlers initialMediaHandlers -- the custom handlers will override the initial ones - , dbTimezones = tzones - - , dbTablesFuzzyIndex = - -- Only build fuzzy index for schemas with a reasonable number of tables - -- Fuzzy.FuzzySet is memory heavy we just don't use it for large schemas - Fuzzy.fromList <$> HM.filter ((< maxDbTablesForFuzzySearch) . length) (HM.fromListWith (<>) ((qiSchema &&& pure . qiName) <$> HM.keys tabsWViewsPks)) - }, qsTime) +querySchemaCache :: AppConfig -> SQL.Transaction (SchemaCache, Maybe SQL.QueryTimings) +querySchemaCache conf@AppConfig{..} = + -- if configInternalSCLockId is set run queries step-by-step waiting for lock release before each + SQL.runSteppedTransaction @SchemaCacheLabel configInternalSCLockId $ + -- if log level is debug then time queries + SQL.runTimed @SchemaCacheLabel isLogDebug $ do + SQL.sql @NoStep "set local schema ''" -- This voids the search path. The following queries need this for getting the fully qualified name(schema.name) of every db object + tabs <- SQL.statement @Tables conf allTables + keyDeps <- SQL.statement @KeyDependencies conf allViewsKeyDependencies + m2oRels <- SQL.statement @Relationships mempty allM2OandO2ORels + funcs <- SQL.statement @Functions conf allFunctions + cRels <- SQL.statement @ComputedRelationships mempty allComputedRels + reps <- SQL.statement @DataRepresentations conf dataRepresentations + mHdlers <- SQL.statement @MediaHandlers conf mediaHandlers + tzones <- if configDbTimezoneEnabled + then SQL.statement @Timezones mempty timezones + else pure S.empty + + let tabsWViewsPks = addViewPrimaryKeys tabs keyDeps + rels = addInverseRels $ addM2MRels tabsWViewsPks $ addViewM2OAndO2ORels keyDeps m2oRels + + return $ removeInternal schemas $ SchemaCache { + dbTables = tabsWViewsPks + , dbRelationships = getOverrideRelationshipsMap rels cRels + , dbRoutines = funcs + , dbRepresentations = reps + , dbMediaHandlers = HM.union mHdlers initialMediaHandlers -- the custom handlers will override the initial ones + , dbTimezones = tzones + + , dbTablesFuzzyIndex = + -- Only build fuzzy index for schemas with a reasonable number of tables + -- Fuzzy.FuzzySet is memory heavy we just don't use it for large schemas + Fuzzy.fromList <$> HM.filter ((< maxDbTablesForFuzzySearch) . length) (HM.fromListWith (<>) ((qiSchema &&& pure . qiName) <$> HM.keys tabsWViewsPks)) + } + -- only used for testing + -- TODO remove configInternalSCQuerySleep once all tests are migrated to stepped execution + <* for_ configInternalSCQuerySleep (\sleep -> SQL.statement @NoStep sleep sleepCall) where schemas = toList configDbSchemas isLogDebug = configLogLevel == LogDebug - sqlTimedStmt = sqlTimedStatement isLogDebug + sleepCall = SQL.Statement "select pg_sleep($1 / 1000.0)" (param HE.int4) HD.noResult True -- | overrides detected relationships with the computed relationships and gets the RelationshipsMap getOverrideRelationshipsMap :: [Relationship] -> [Relationship] -> RelationshipsMap @@ -1155,71 +1161,19 @@ nullableColumn = HD.column . HD.nullable arrayColumn :: HD.Value a -> HD.Row [a] arrayColumn = column . HD.listArray . HD.nonNullable -{- - - Times a sql statement inside a transaction, for this: - - - - 1. We start a timer: select set_config('pgrst.tmp_x', clock_timestamp()::text, false); - - 2. Run the statement: select .... - - 3. End the timer: select set_config('pgrst.tmp_x', (clock_timestamp() - current_setting('pgrst.tmp_x', false)::timestamptz)::text, false); - - - - We can do this for several statements inside the transaction. The timings are later captured at the end of the transaction with extractTimings. - -} -sqlTimedStatement :: Bool -> ByteString -> a -> SQL.Statement a b -> SQL.Transaction b -sqlTimedStatement isLogDebug guc params stmt = - if isLogDebug then - SQL.sql sFrag >> SQL.statement params stmt <* SQL.sql eFrag - else - SQL.statement params stmt - where - sFrag = "select set_config('pgrst." <> guc <> "', clock_timestamp()::text, true)" - eFrag = "select set_config('pgrst." <> guc <> "', (clock_timestamp() - current_setting('pgrst." <> guc <> "', false)::timestamptz)::text, true)" - --- Extract all the generated timings (see sqlTimedStatement) converting the value to milliseconds. -extractTimings :: Bool -> SQL.Statement () QueryTimings -extractTimings hasTimezones = SQL.Statement sql HE.noParams decodeThem True - where - qFrag setting = "extract('milliseconds' from current_setting('pgrst." <> setting <> "', false)::interval)::text" - sql = "SELECT " <> BS.intercalate "," - [ qFrag gucTbls, qFrag gucKDeps, qFrag gucRels - , qFrag gucFuncs, qFrag gucCRels, qFrag gucDReps - , qFrag gucMHdrs, if hasTimezones then qFrag gucTzones else "'0.0'" - ] - decodeThem :: HD.Result QueryTimings - decodeThem = HD.singleRow $ - QueryTimings - <$> column HD.text <*> column HD.text <*> column HD.text - <*> column HD.text <*> column HD.text <*> column HD.text - <*> column HD.text <*> column HD.text - -data QueryTimings = QueryTimings - { qtTables :: Text - , qtKeyDeps :: Text - , qtRels :: Text - , qtFuncs :: Text - , qtCRels :: Text - , qtDReps :: Text - , qtMHdrs :: Text - , qtTzones :: Text - } deriving (Show) - -queryTimingsWLabels :: QueryTimings -> [(ByteString, Text)] -queryTimingsWLabels qt = - [ (gucTbls, qtTables qt) - , (gucKDeps, qtKeyDeps qt) - , (gucRels, qtRels qt) - , (gucFuncs, qtFuncs qt) - , (gucCRels, qtCRels qt) - , (gucDReps, qtDReps qt) - , (gucMHdrs, qtMHdrs qt) - , (gucTzones, qtTzones qt) - ] - -gucTbls, gucKDeps, gucRels, gucFuncs, gucCRels, gucDReps, gucMHdrs, gucTzones :: ByteString -gucTbls = "tables" -gucKDeps = "keydeps" -gucRels = "rels" -gucFuncs = "funcs" -gucCRels = "comprels" -gucDReps = "dreps" -gucMHdrs = "mhandlers" -gucTzones = "tzones" +data SchemaCacheLabel = Step SQL.LockSpec SQL.TimingSpec + +instance SQL.TransactionKind SchemaCacheLabel where + type LabelConstraint SchemaCacheLabel label = (SQL.SqlBreakpoint label, SQL.SqlTiming label) +instance SQL.HasTimingsQueryLabel SchemaCacheLabel where + type TimingsQueryLabel SchemaCacheLabel = NoStep + +type Tables = Step (SQL.Lock 0) (SQL.Timing "tables") +type KeyDependencies = Step (SQL.Lock 1) (SQL.Timing "keydeps") +type Relationships = Step (SQL.Lock 2) (SQL.Timing "rels") +type Functions = Step (SQL.Lock 3) (SQL.Timing "funcs") +type ComputedRelationships = Step (SQL.Lock 4) (SQL.Timing "comprels") +type DataRepresentations = Step (SQL.Lock 5) (SQL.Timing "dreps") +type MediaHandlers = Step (SQL.Lock 6) (SQL.Timing "mhandlers") +type Timezones = Step (SQL.Lock 7) (SQL.Timing "tzones") +type NoStep = Step SQL.NoLock SQL.NoTiming diff --git a/src/PostgREST/SqlTransaction.hs b/src/PostgREST/SqlTransaction.hs new file mode 100644 index 0000000000..ed5228dd66 --- /dev/null +++ b/src/PostgREST/SqlTransaction.hs @@ -0,0 +1,185 @@ +{-# LANGUAGE AllowAmbiguousTypes #-} +{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GeneralisedNewtypeDeriving #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE QuantifiedConstraints #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneKindSignatures #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeAbstractions #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE UndecidableSuperClasses #-} + +module PostgREST.SqlTransaction + ( SqlTransaction + , SqlCondemn + , Condemnable + , HasTimingsQueryLabel(..) + , QueryTimings + , TransactionConstraint + , SqlTx + , LockSpec(..) + , SqlBreakpoint(..) + , TimingSpec(..) + , SqlTiming + , TransactionKind(..) + , condemn + , sql + , statement + , runSteppedTransaction + , runTimed + ) where + +import Protolude + +import qualified Control.Monad.Writer.Strict as Writer +import Data.Functor.Contravariant ((>$<)) +import qualified Data.List as L +import qualified Data.Text as T +import qualified Hasql.Decoders as HD +import qualified Hasql.Encoders as HE +import qualified Hasql.Statement as SQL +import qualified Hasql.Transaction as SQL + +type QueryTimings = [(ByteString, Text)] + +class TransactionKind k where + type LabelConstraint k (label :: k) :: Constraint + type TransactionCapability k (m :: Type -> Type) :: Constraint + type TransactionCapability k m = () + +type TransactionConstraint k m = (SqlTransaction k m, TransactionCapability k m) + +class (Monad m, TransactionKind k) => SqlTransaction k (m :: Type -> Type) where + runSql :: forall (label :: k). LabelConstraint k label => ByteString -> m () + runStatement :: forall (label :: k) a b. LabelConstraint k label => a -> SQL.Statement a b -> m b + +type SqlTx k a = forall m. TransactionConstraint k m => m a + +class SqlCondemn (m :: Type -> Type) where + condemn :: m () + +class (SqlTransaction k m, SqlCondemn m) => Condemnable k m +instance (SqlTransaction k m, SqlCondemn m) => Condemnable k m + +sql :: forall {k} (label :: k) m. (SqlTransaction k m, LabelConstraint k label) => ByteString -> m () +sql = runSql @k @m @label + +statement :: forall {k} (label :: k) m a b. (SqlTransaction k m, LabelConstraint k label) => a -> SQL.Statement a b -> m b +statement = runStatement @k @m @label + +instance TransactionKind k => SqlTransaction k SQL.Transaction where + runSql = SQL.sql + runStatement = SQL.statement + +instance SqlCondemn SQL.Transaction where + condemn = SQL.condemn + +instance (Monad m, SqlCondemn m, Writer.MonadTrans t) => SqlCondemn (t m) where + condemn = lift condemn + +newtype WithLock m a = WithLock (ReaderT Int32 m a) + deriving newtype (Functor, Applicative, Monad, Writer.MonadTrans) + +data LockSpec = NoLock | Lock Nat +data TimingSpec = NoTiming | Timing Symbol + +type SqlBreakpoint :: forall {k}. k -> Constraint +class SqlLock (BreakpointLock label) => SqlBreakpoint (label :: k) where + type BreakpointLock label :: LockSpec + +type SqlLock :: LockSpec -> Constraint +class SqlLock lock where + lockNext :: forall k (label :: k) m a. (SqlTransaction k m, LabelConstraint k label) => m a -> WithLock m a + +instance KnownNat lock => SqlLock (Lock lock) where + lockNext @_ @label @m tx = WithLock $ do + lockId <- ask + lift $ + statement @label @m + (lockId, fromIntegral $ natVal (Proxy @lock)) + ( SQL.Statement + "SELECT pg_advisory_xact_lock($1, $2)" + ((fst >$< param HE.int4) <> (snd >$< param HE.int2)) + HD.noResult + False + ) + *> tx + +instance SqlLock NoLock where + lockNext = lift + +instance SqlLock lock => SqlBreakpoint (step lock whatever) where + type BreakpointLock (step lock whatever) = lock + +instance forall k m. (SqlTransaction k m, forall (label :: k). LabelConstraint k label => SqlBreakpoint label) => SqlTransaction k (WithLock m) where + runSql @label query = + lockNext @(BreakpointLock label) @k @label @m $ sql @label @m query + + runStatement @label params stmt = + lockNext @(BreakpointLock label) @k @label @m $ statement @label @m params stmt + +runSteppedTransaction :: forall k m a. (TransactionKind k, TransactionConstraint k m, TransactionConstraint k (WithLock m), forall (label :: k). LabelConstraint k label => SqlBreakpoint label) => Maybe Int32 -> SqlTx k a -> m a +runSteppedTransaction l tx = maybe tx (stepped tx) l + where + stepped (WithLock m) = runReaderT m + +newtype GucTimed k m a = GucTimed (Writer.WriterT [ByteString] m a) + deriving newtype (Functor, Applicative, Monad, Writer.MonadTrans) + +class SqlTiming (label :: k) where + timeTransactionForLabel :: forall m a. (SqlTransaction k m, LabelConstraint k label) => m a -> GucTimed k m a + +instance SqlTiming (step lock NoTiming) where + timeTransactionForLabel = lift + +instance KnownSymbol guc => SqlTiming (step lock (Timing guc)) where + timeTransactionForLabel @m tx = + GucTimed $ Writer.WriterT (pure ((), [gucName])) *> timedTx + where + GucTimed timedTx = lift $ sFrag gucName *> tx <* eFrag gucName + gucName = encodeUtf8 . T.pack $ symbolVal (Proxy @guc) + sFrag name = sql @(step lock (Timing guc)) @m $ "select set_config('pgrst." <> name <> "', clock_timestamp()::text, true)" + eFrag name = sql @(step lock (Timing guc)) @m $ "select set_config('pgrst." <> name <> "', (clock_timestamp() - current_setting('pgrst." <> name <> "', false)::timestamptz)::text, true)" + +instance forall k m. (SqlTransaction k m, forall (label :: k). LabelConstraint k label => SqlTiming label) => SqlTransaction k (GucTimed k m) where + runSql @label query = + timeTransactionForLabel @k @label $ sql @label @m query + + runStatement @label params stmt = + timeTransactionForLabel @k @label $ statement @label @m params stmt + +class HasTimingsQueryLabel k where + type TimingsQueryLabel k :: k + +runTimed :: forall k m a. (TransactionKind k, HasTimingsQueryLabel k, TransactionConstraint k m, TransactionConstraint k (GucTimed k m), LabelConstraint k (TimingsQueryLabel k), forall (label :: k). LabelConstraint k label => SqlTiming label) => Bool -> SqlTx k a -> m (a, Maybe QueryTimings) +runTimed isDebug tx = + if isDebug then timed tx else (, mempty) <$> tx + where + timed :: SqlTransaction k m => GucTimed k m a -> m (a, Maybe QueryTimings) + timed (GucTimed m) = do + (result, labels) <- Writer.runWriterT m + (result,) . Just <$> + statement @(TimingsQueryLabel k) @m + (decodeUtf8 <$> L.nub labels) + ( SQL.Statement + ( "SELECT name, extract('milliseconds' from duration)::text " <> + "FROM (SELECT name, nullif(current_setting('pgrst.' || name, true), '')::interval AS duration " <> + "FROM unnest($1::text[]) AS timing(name)) pgrst_timings " <> + "WHERE duration IS NOT NULL" + ) + (param . HE.foldableArray . HE.nonNullable $ HE.text) + (HD.rowList $ ((,) . encodeUtf8 <$> HD.column (HD.nonNullable HD.text)) <*> HD.column (HD.nonNullable HD.text)) + True + ) + +param :: HE.Value a -> HE.Params a +param = HE.param . HE.nonNullable diff --git a/test/io/conftest.py b/test/io/conftest.py index b6a068222f..3dcfdb701b 100644 --- a/test/io/conftest.py +++ b/test/io/conftest.py @@ -1,7 +1,7 @@ import os import pytest from syrupy.extensions.json import SingleFileSnapshotExtension -from postgrest import run +from postgrest import SchemaCacheLocks, run @pytest.fixture @@ -57,16 +57,13 @@ def replicaenv(defaultenv): @pytest.fixture -def slow_schema_cache_env(defaultenv): - "Slow schema cache load environment PostgREST." - return { - **defaultenv, - "PGRST_INTERNAL_SCHEMA_CACHE_QUERY_SLEEP": "1000", # this does a pg_sleep internally, it will cause the schema cache query to be slow - # the slow schema cache query will keep using one pool connection until it finishes - # to prevent requests waiting for PGRST_DB_POOL_ACQUISITION_TIMEOUT we'll increase the pool size (must be >= 2) - "PGRST_DB_POOL": "2", - "PGRST_DB_CHANNEL_ENABLED": "true", - } +def schema_cache_locks(baseenv): + "Factory for controlling step-by-step querySchemaCache execution." + + def factory(lock_id=None, max_step=15): + return SchemaCacheLocks(baseenv, lock_id=lock_id, max_step=max_step) + + return factory @pytest.fixture diff --git a/test/io/postgrest.py b/test/io/postgrest.py index 65431035cd..adc91457e5 100644 --- a/test/io/postgrest.py +++ b/test/io/postgrest.py @@ -11,7 +11,9 @@ import time import string import urllib.parse +import uuid +import psycopg import requests import requests_unixsocket @@ -58,6 +60,76 @@ def request(self, method, url, *args, **kwargs): return super(PostgrestSession, self).request(method, fullurl, *args, **kwargs) +class SchemaCacheLocks: + """ + Hold advisory locks that make querySchemaCache block before each query. + + Use the generated lock_id as PGRST_INTERNAL_SCHEMA_CACHE_LOCK_ID. The + helper keeps a single database session open, so locks stay held until a + step is explicitly unlocked or the context exits. + """ + + def __init__(self, env, lock_id=None, max_step=15): + self.env = env + self.lock_id = (uuid.uuid4().int % 2147483647) + 1 + if lock_id is not None: + self.lock_id = int(lock_id) + self.max_step = self._step(max_step) + self._conn = None + + def __enter__(self): + self._conn = psycopg.connect( + dbname=self.env["PGDATABASE"], + host=self.env["PGHOST"], + user=self.env["PGUSER"], + autocommit=True, + ) + try: + self.lock() + except Exception: + self._conn.close() + raise + return self + + def __exit__(self, _exc_type, _exc_value, _traceback): + self._conn.close() + + def lock(self): + "Acquire session-level advisory locks for all configured schema-cache steps." + if self.max_step < 0: + return + + with self._conn.cursor() as cursor: + cursor.execute( + """ + SELECT pg_advisory_lock(%s, lock_number) + FROM generate_series(0, %s::int) AS lock_number + """, + (self.lock_id, self.max_step), + ) + + def unlock(self, step): + "Release a schema-cache step and let querySchemaCache run that query." + with self._conn.cursor() as cursor: + cursor.execute( + "SELECT pg_advisory_unlock(%s, %s)", + (self.lock_id, self._step(step)), + ) + + def unlock_all(self): + "Release all schema-cache step locks still held by this session." + if self._conn is not None and not self._conn.closed: + with self._conn.cursor() as cursor: + cursor.execute("SELECT pg_advisory_unlock_all()") + + @staticmethod + def _step(step): + step = int(step) + if step < 0: + raise ValueError("schema-cache lock step must be non-negative") + return step + + @dataclasses.dataclass class PostgrestProcess: "Running PostgREST process and its corresponding main and admin endpoints." diff --git a/test/io/test_io.py b/test/io/test_io.py index a9a7da03ec..5d024832a9 100644 --- a/test/io/test_io.py +++ b/test/io/test_io.py @@ -45,6 +45,22 @@ def psql_as_superuser(query): ) +def wait_for_response(request, condition, max_seconds=2): + "Poll an HTTP request until its response satisfies the given condition." + deadline = time.monotonic() + max_seconds + response = None + + while time.monotonic() < deadline: + response = request() + if condition(response): + return response + time.sleep(0.05) + + assert response is not None + assert condition(response), f"{response.status_code}: {response.text}" + return response + + def test_connect_with_dburi(dburi, defaultenv): "Connecting with db-uri instead of LIPQ* environment variables should work." defaultenv_without_libpq = { @@ -1388,36 +1404,44 @@ def test_isolation_level(defaultenv): assert response.text == '"serializable"' -def test_schema_cache_concurrent_notifications(slow_schema_cache_env): +def test_schema_cache_concurrent_notifications(defaultenv, schema_cache_locks): "schema cache should be up-to-date whenever a notification is sent while another reload is in progress, see https://github.com/PostgREST/postgrest/issues/2791" - internal_sleep = ( - int(slow_schema_cache_env["PGRST_INTERNAL_SCHEMA_CACHE_QUERY_SLEEP"]) / 1000 - ) - - with run(env=slow_schema_cache_env, wait_for=None) as postgrest: - time.sleep(2 * internal_sleep + 0.1) # wait for readiness manually + locks = schema_cache_locks() + env = { + **defaultenv, + "PGRST_INTERNAL_SCHEMA_CACHE_LOCK_ID": str(locks.lock_id), + # the blocked schema cache query keeps using one pool connection until it + # finishes. Keep another connection available for requests that send + # concurrent reload notifications. + "PGRST_DB_POOL": "2", + "PGRST_DB_CHANNEL_ENABLED": "true", + } - # first request, create a function and set a schema cache reload in progress - response = postgrest.session.post("/rpc/create_function") - assert response.text == "" - assert response.status_code == 204 + with run(env=env, wait_max_seconds=5) as postgrest: + with locks: + # first request, create a function and set a schema cache reload in progress + response = postgrest.session.post("/rpc/create_function") + assert response.text == "" + assert response.status_code == 204 - time.sleep( - internal_sleep / 2 - ) # wait to be inside the schema cache reload process + # Let the in-progress schema cache reload query functions before + # sending another reload notification. + for step in range(4): + locks.unlock(step) - # second request, change the same function and do another schema cache reload - response = postgrest.session.post("/rpc/migrate_function") - assert response.text == "" - assert response.status_code == 204 + # second request, change the same function and do another schema cache reload + response = postgrest.session.post("/rpc/migrate_function") + assert response.text == "" + assert response.status_code == 204 - time.sleep( - 2 * internal_sleep - ) # wait enough time to get the final schema cache state + locks.unlock_all() # confirm the schema cache is up-to-date and the 2nd reload wasn't lost - response = postgrest.session.get("/rpc/mult_them?c=3&d=4") + response = wait_for_response( + lambda: postgrest.session.get("/rpc/mult_them?c=3&d=4"), + lambda response: response.text == "12" and response.status_code == 200, + ) assert response.text == "12" assert response.status_code == 200 @@ -1455,12 +1479,10 @@ def test_schema_cache_query_timings_log(level, timezone_enabled, defaultenv): env = { **defaultenv, "PGRST_LOG_LEVEL": level, - # when this is disabled, it should log 0 for tzones "PGRST_DB_TIMEZONE_ENABLED": timezone_enabled, } - # here we also capture the tzones: ms log_pattern = re.compile( - r".+: tables: [\d.]+ ms, keydeps: [\d.]+ ms, rels: [\d.]+ ms, funcs: [\d.]+ ms, comprels: [\d.]+ ms, dreps: [\d.]+ ms, mhandlers: [\d.]+ ms, tzones: ([\d.]+) ms" + r".+: tables: [\d.]+ ms, keydeps: [\d.]+ ms, rels: [\d.]+ ms, funcs: [\d.]+ ms, comprels: [\d.]+ ms, dreps: [\d.]+ ms, mhandlers: [\d.]+ ms(?:, tzones: ([\d.]+) ms)?" ) with run(env=env, no_startup_stdout=False) as postgrest: @@ -1472,8 +1494,9 @@ def test_schema_cache_query_timings_log(level, timezone_enabled, defaultenv): if level == "debug": assert len(timing_matches) == 1 if timezone_enabled == "false": - assert float(timing_matches[0].group(1)) == 0 + assert timing_matches[0].group(1) is None else: + assert timing_matches[0].group(1) is not None assert float(timing_matches[0].group(1)) > 0 else: assert not timing_matches diff --git a/test/observability/ObsHelper.hs b/test/observability/ObsHelper.hs index 268bdbf67c..70ab8f1e55 100644 --- a/test/observability/ObsHelper.hs +++ b/test/observability/ObsHelper.hs @@ -118,6 +118,7 @@ baseCfg = let secret = encodeUtf8 "reallyreallyreallyreallyverysafe" in , configRoleSettings = mempty , configRoleIsoLvl = mempty , configInternalSCQuerySleep = Nothing + , configInternalSCLockId = Nothing , configServerTimingEnabled = True } diff --git a/test/spec/SpecHelper.hs b/test/spec/SpecHelper.hs index 4e45a5a31b..03595c0847 100644 --- a/test/spec/SpecHelper.hs +++ b/test/spec/SpecHelper.hs @@ -159,6 +159,7 @@ baseCfg = let secret = encodeUtf8 "reallyreallyreallyreallyverysafe" in , configRoleSettings = mempty , configRoleIsoLvl = mempty , configInternalSCQuerySleep = Nothing + , configInternalSCLockId = Nothing , configServerTimingEnabled = True }