Skip to content

Commit

Permalink
Allow preloading pgtt in session_preload_libraries
Browse files Browse the repository at this point in the history
The module is now lazily loaded when it might be needed (after parse analysis,
before executor startup and before utility statements execution), and all the
code is simply bypassed if the underlying extension has not been created.
  • Loading branch information
rjuju committed Apr 12, 2024
1 parent 3056141 commit 513b416
Showing 1 changed file with 103 additions and 55 deletions.
158 changes: 103 additions & 55 deletions pgtt.c
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ static void gtt_post_parse_analyze(ParseState *pstate, Query *query, struct Jumb
#else
static void gtt_post_parse_analyze(ParseState *pstate, Query *query);
#endif
static void gtt_try_load(void);
#if PG_VERSION_NUM < 160000
Oid get_extension_schema(Oid ext_oid);
#endif
Expand Down Expand Up @@ -217,7 +218,7 @@ static Oid gtt_create_table_statement(Gtt gtt);
static void gtt_create_table_as(Gtt gtt, bool skipdata);
static void gtt_unregister_global_temporary_table(Oid relid, const char *relname);
void GttHashTableDeleteAll(void);
void EnableGttManager(void);
bool EnableGttManager(void);
Gtt GetGttByName(const char *name);
static void gtt_load_global_temporary_tables(void);
static Oid create_temporary_table_internal(Oid parent_relid, bool preserved);
Expand All @@ -236,21 +237,11 @@ static void gtt_unregister_gtt_not_cached(const char *relname);
void
_PG_init(void)
{
static bool inited = false;

elog(DEBUG1, "_PG_init() : %d", inited);

/*
* Be sure we do initialization only once.
*
* If initialization fails due to, e.g., ERROR: extension "pgtt" does
* not exist, then we'll return here on the next usage.
*/
if (inited)
return;
elog(DEBUG1, "_PG_init()");

if (ParallelWorkerNumber >= 0)
return;

/*
* If we are loaded via shared_preload_libraries exit.
*/
Expand All @@ -261,31 +252,6 @@ _PG_init(void)
errhint("Use \"LOAD 'pgtt';\" in the running session instead.")));
}

if (!IsTransactionState())
{
ereport(FATAL,
(errmsg("The pgtt extension can not be loaded using session_preload_libraries."),
errhint("Use \"LOAD 'pgtt';\" in the running session instead.")));
}

if (GttHashTable == NULL)
{
/* Initialize list of Global Temporary Table */
EnableGttManager();

/*
* Load temporary table definition from pg_global_temp_tables table
* into our Hash table and pre-create the temporary tables.
*/
gtt_load_global_temporary_tables();
}

/*
* Be sure that extension schema is at end of the search path so that
* "template" tables will be found.
*/
force_pgtt_namespace();

/*
* Define (or redefine) custom GUC variables.
* No custom GUC variable at this time
Expand All @@ -303,8 +269,6 @@ _PG_init(void)
NULL,
NULL);

inited = true;

/*
* Install hooks.
*/
Expand Down Expand Up @@ -351,12 +315,14 @@ gtt_ProcessUtility(GTT_PROCESSUTILITY_PROTO)
/* Do not waste time here if the feature is not enabled for this session */
if (pgtt_is_enabled && NOT_IN_PARALLEL_WORKER)
{
/* Try to load pgtt if not already done. */
gtt_try_load();

/*
* Be sure that extension schema is at end of the search path so that
* "template" tables will be find.
*/
if (IsTransactionState())
force_pgtt_namespace();
force_pgtt_namespace();

/*
* Check if we have a CREATE GLOBAL TEMPORARY TABLE
Expand Down Expand Up @@ -415,6 +381,9 @@ gtt_check_command(GTT_PROCESSUTILITY_PROTO)

elog(DEBUG1, "gtt_check_command() on query: \"%s\"", queryString);

if (GttHashTable == NULL)
return false;

/* Intercept CREATE / DROP TABLE statements */
switch (nodeTag(parsetree))
{
Expand Down Expand Up @@ -999,6 +968,8 @@ gtt_ExecutorStart(QueryDesc *queryDesc, int eflags)
/* Do not waste time here if the feature is not enabled for this session */
if (pgtt_is_enabled && NOT_IN_PARALLEL_WORKER)
{
/* Try to load pgtt if not already done. */
gtt_try_load();

/* check if we are working on a GTT and create it if it doesn't exist */
if (queryDesc->operation == CMD_INSERT
Expand Down Expand Up @@ -1034,7 +1005,7 @@ gtt_table_exists(QueryDesc *queryDesc)
Gtt gtt;
PlannedStmt *pstmt = (PlannedStmt *) queryDesc->plannedstmt;

if (!pstmt)
if (GttHashTable == NULL || !pstmt)
return false;

/* no relation in rtable probably a function call */
Expand Down Expand Up @@ -1295,6 +1266,36 @@ gtt_unregister_gtt_not_cached(const char *relname)
#endif
}

/*
* Check if pgtt hasn't been loaded yet, and try to load it in that case.
*/
static void
gtt_try_load(void)
{
/*
* Don't try to load if the extension is disabled, if we can't do it now or
* if it's already loaded.
*/
if (!pgtt_is_enabled || !IsTransactionState() || GttHashTable != NULL)
return;

/* Initialize list of Global Temporary Table */
if (EnableGttManager())
{
/*
* Load temporary table definition from pg_global_temp_tables table
* into our Hash table and pre-create the temporary tables.
*/
gtt_load_global_temporary_tables();

/*
* Be sure that extension schema is at end of the search path so that
* "template" tables will be found.
*/
force_pgtt_namespace();
}
}

#if PG_VERSION_NUM < 160000
/*
* From src/backend/commands/extension.c
Expand Down Expand Up @@ -1351,10 +1352,30 @@ get_extension_schema(Oid ext_oid)
* EnableGttManager
* Enables the GTT management cache at backend startup.
*/
void
bool
EnableGttManager(void)
{
Oid extOid = get_extension_oid("pgtt", false);
Oid extOid = get_extension_oid("pgtt", true);
RangeVar *rv;
char *nspname;

if (!OidIsValid(extOid))
return false;

pgtt_namespace_oid = get_extension_schema(extOid);
if (!OidIsValid(pgtt_namespace_oid))
elog(ERROR, "namespace %d can not be found.", pgtt_namespace_oid);

/*
* Check if the GTT relation also exist. We might be in the middle of the
* extension creation, where the line in pg_extension exists but not the
* rest of SQL objects.
*/
nspname = get_namespace_name(pgtt_namespace_oid);
rv = makeRangeVar(nspname, CATALOG_GLOBAL_TEMP_REL, -1);

if (!OidIsValid(RangeVarGetRelid(rv, AccessShareLock, true)))
return false;

if (GttHashTable == NULL)
{
Expand All @@ -1378,11 +1399,13 @@ EnableGttManager(void)
elog(DEBUG1, "GTT cache initialized.");
}

/* Set the OID and name of the extension schema, all objects will be created in this schema */
pgtt_namespace_oid = get_extension_schema(extOid);
if (!OidIsValid(pgtt_namespace_oid))
elog(ERROR, "namespace %d can not be found.", pgtt_namespace_oid);
strcpy(pgtt_namespace_name, get_namespace_name(pgtt_namespace_oid));
/*
* Set the OID and name of the extension schema, all objects will be
* created in this schema.
*/
strcpy(pgtt_namespace_name, nspname);

return true;
}

/*
Expand All @@ -1408,18 +1431,23 @@ GttHashTableDeleteAll(void)
/* Restart the iteration in case that led to other drops */
hash_seq_term(&status);
hash_seq_init(&status, GttHashTable);
}
}
}

/*
* GetGttByName
* Returns a Gtt given a table name, or NULL if name is not found.
*
* Caller should have made sure that GTT has been properly loaded before
* calling this function.
*/
Gtt
GetGttByName(const char *name)
{
Gtt gtt;

Assert(GttHashTable != NULL);

if (PointerIsValid(name))
GttHashTableLookup(name, gtt);

Expand Down Expand Up @@ -1450,10 +1478,12 @@ gtt_load_global_temporary_tables(void)

/* Set and open the GTT definition storage relation */
rv = makeRangeVar(pgtt_namespace_name, CATALOG_GLOBAL_TEMP_REL, -1);

/* Open the CATALOG_GLOBAL_TEMP_REL table. We don't want to allow
* writable accesses by other session during import. */
PushActiveSnapshot(GetTransactionSnapshot());
snapshot = GetActiveSnapshot();
//snapshot = GetTransactionSnapshot();

#if (PG_VERSION_NUM >= 120000)
rel = table_openrv(rv, AccessShareLock);
scan = table_beginscan(rel, snapshot, 0, (ScanKey) NULL);
Expand Down Expand Up @@ -1489,6 +1519,8 @@ gtt_load_global_temporary_tables(void)
heap_endscan(scan);
heap_close(rel, AccessShareLock);
#endif

PopActiveSnapshot();
}

static Oid
Expand Down Expand Up @@ -1733,7 +1765,11 @@ gtt_post_parse_analyze(ParseState *pstate, Query *query, struct JumbleState * js
gtt_post_parse_analyze(ParseState *pstate, Query *query)
#endif
{
if (NOT_IN_PARALLEL_WORKER && pgtt_is_enabled && query->rtable != NIL)
/* Try to load pgtt if not already done. */
gtt_try_load();

if (NOT_IN_PARALLEL_WORKER && pgtt_is_enabled && query->rtable != NIL &&
GttHashTable != NULL)
{
/* replace the Oid of the template table by our new table in the rtable */
RangeTblEntry *rte = (RangeTblEntry *) linitial(query->rtable);
Expand Down Expand Up @@ -1864,16 +1900,25 @@ static void
force_pgtt_namespace (void)
{
#if PG_VERSION_NUM >= 170000
SearchPathMatcher *overridePath = GetSearchPathMatcher(CurrentMemoryContext);
SearchPathMatcher *overridePath;
#else
OverrideSearchPath *overridePath = GetOverrideSearchPath(CurrentMemoryContext);
OverrideSearchPath *overridePath;
#endif
ListCell *lc;
Oid schemaId = InvalidOid;
StringInfoData search_path;
bool found = false;
bool first = true;

if (!IsTransactionState() || GttHashTable == NULL)
return;

#if PG_VERSION_NUM >= 170000
overridePath = GetSearchPathMatcher(CurrentMemoryContext);
#else
overridePath = GetOverrideSearchPath(CurrentMemoryContext);
#endif

initStringInfo(&search_path);
/* verify that extension schema is in the path */
foreach(lc, overridePath->schemas)
Expand Down Expand Up @@ -1961,6 +2006,9 @@ gtt_create_table_as(Gtt gtt, bool skipdata)

elog(DEBUG1, "proceeding to Global Temporary Table creation.");

/* This can only be called if GTT has been properly loaded. */
Assert(GttHashTable != NULL);

connected = SPI_connect();
if (connected != SPI_OK_CONNECT)
ereport(ERROR, (errmsg("could not connect to SPI manager")));
Expand Down

0 comments on commit 513b416

Please sign in to comment.