Skip to content

Commit

Permalink
Merge pull request #44 from rjuju/preload
Browse files Browse the repository at this point in the history
Allow preloading pgtt in session_preload_libraries
  • Loading branch information
darold authored Apr 13, 2024
2 parents 3056141 + 513b416 commit 5d8c89e
Showing 1 changed file with 103 additions and 55 deletions.
158 changes: 103 additions & 55 deletions pgtt.c
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ static void gtt_post_parse_analyze(ParseState *pstate, Query *query, struct Jumb
#else
static void gtt_post_parse_analyze(ParseState *pstate, Query *query);
#endif
static void gtt_try_load(void);
#if PG_VERSION_NUM < 160000
Oid get_extension_schema(Oid ext_oid);
#endif
Expand Down Expand Up @@ -217,7 +218,7 @@ static Oid gtt_create_table_statement(Gtt gtt);
static void gtt_create_table_as(Gtt gtt, bool skipdata);
static void gtt_unregister_global_temporary_table(Oid relid, const char *relname);
void GttHashTableDeleteAll(void);
void EnableGttManager(void);
bool EnableGttManager(void);
Gtt GetGttByName(const char *name);
static void gtt_load_global_temporary_tables(void);
static Oid create_temporary_table_internal(Oid parent_relid, bool preserved);
Expand All @@ -236,21 +237,11 @@ static void gtt_unregister_gtt_not_cached(const char *relname);
void
_PG_init(void)
{
static bool inited = false;

elog(DEBUG1, "_PG_init() : %d", inited);

/*
* Be sure we do initialization only once.
*
* If initialization fails due to, e.g., ERROR: extension "pgtt" does
* not exist, then we'll return here on the next usage.
*/
if (inited)
return;
elog(DEBUG1, "_PG_init()");

if (ParallelWorkerNumber >= 0)
return;

/*
* If we are loaded via shared_preload_libraries exit.
*/
Expand All @@ -261,31 +252,6 @@ _PG_init(void)
errhint("Use \"LOAD 'pgtt';\" in the running session instead.")));
}

if (!IsTransactionState())
{
ereport(FATAL,
(errmsg("The pgtt extension can not be loaded using session_preload_libraries."),
errhint("Use \"LOAD 'pgtt';\" in the running session instead.")));
}

if (GttHashTable == NULL)
{
/* Initialize list of Global Temporary Table */
EnableGttManager();

/*
* Load temporary table definition from pg_global_temp_tables table
* into our Hash table and pre-create the temporary tables.
*/
gtt_load_global_temporary_tables();
}

/*
* Be sure that extension schema is at end of the search path so that
* "template" tables will be found.
*/
force_pgtt_namespace();

/*
* Define (or redefine) custom GUC variables.
* No custom GUC variable at this time
Expand All @@ -303,8 +269,6 @@ _PG_init(void)
NULL,
NULL);

inited = true;

/*
* Install hooks.
*/
Expand Down Expand Up @@ -351,12 +315,14 @@ gtt_ProcessUtility(GTT_PROCESSUTILITY_PROTO)
/* Do not waste time here if the feature is not enabled for this session */
if (pgtt_is_enabled && NOT_IN_PARALLEL_WORKER)
{
/* Try to load pgtt if not already done. */
gtt_try_load();

/*
* Be sure that extension schema is at end of the search path so that
* "template" tables will be find.
*/
if (IsTransactionState())
force_pgtt_namespace();
force_pgtt_namespace();

/*
* Check if we have a CREATE GLOBAL TEMPORARY TABLE
Expand Down Expand Up @@ -415,6 +381,9 @@ gtt_check_command(GTT_PROCESSUTILITY_PROTO)

elog(DEBUG1, "gtt_check_command() on query: \"%s\"", queryString);

if (GttHashTable == NULL)
return false;

/* Intercept CREATE / DROP TABLE statements */
switch (nodeTag(parsetree))
{
Expand Down Expand Up @@ -999,6 +968,8 @@ gtt_ExecutorStart(QueryDesc *queryDesc, int eflags)
/* Do not waste time here if the feature is not enabled for this session */
if (pgtt_is_enabled && NOT_IN_PARALLEL_WORKER)
{
/* Try to load pgtt if not already done. */
gtt_try_load();

/* check if we are working on a GTT and create it if it doesn't exist */
if (queryDesc->operation == CMD_INSERT
Expand Down Expand Up @@ -1034,7 +1005,7 @@ gtt_table_exists(QueryDesc *queryDesc)
Gtt gtt;
PlannedStmt *pstmt = (PlannedStmt *) queryDesc->plannedstmt;

if (!pstmt)
if (GttHashTable == NULL || !pstmt)
return false;

/* no relation in rtable probably a function call */
Expand Down Expand Up @@ -1295,6 +1266,36 @@ gtt_unregister_gtt_not_cached(const char *relname)
#endif
}

/*
* Check if pgtt hasn't been loaded yet, and try to load it in that case.
*/
static void
gtt_try_load(void)
{
/*
* Don't try to load if the extension is disabled, if we can't do it now or
* if it's already loaded.
*/
if (!pgtt_is_enabled || !IsTransactionState() || GttHashTable != NULL)
return;

/* Initialize list of Global Temporary Table */
if (EnableGttManager())
{
/*
* Load temporary table definition from pg_global_temp_tables table
* into our Hash table and pre-create the temporary tables.
*/
gtt_load_global_temporary_tables();

/*
* Be sure that extension schema is at end of the search path so that
* "template" tables will be found.
*/
force_pgtt_namespace();
}
}

#if PG_VERSION_NUM < 160000
/*
* From src/backend/commands/extension.c
Expand Down Expand Up @@ -1351,10 +1352,30 @@ get_extension_schema(Oid ext_oid)
* EnableGttManager
* Enables the GTT management cache at backend startup.
*/
void
bool
EnableGttManager(void)
{
Oid extOid = get_extension_oid("pgtt", false);
Oid extOid = get_extension_oid("pgtt", true);
RangeVar *rv;
char *nspname;

if (!OidIsValid(extOid))
return false;

pgtt_namespace_oid = get_extension_schema(extOid);
if (!OidIsValid(pgtt_namespace_oid))
elog(ERROR, "namespace %d can not be found.", pgtt_namespace_oid);

/*
* Check if the GTT relation also exist. We might be in the middle of the
* extension creation, where the line in pg_extension exists but not the
* rest of SQL objects.
*/
nspname = get_namespace_name(pgtt_namespace_oid);
rv = makeRangeVar(nspname, CATALOG_GLOBAL_TEMP_REL, -1);

if (!OidIsValid(RangeVarGetRelid(rv, AccessShareLock, true)))
return false;

if (GttHashTable == NULL)
{
Expand All @@ -1378,11 +1399,13 @@ EnableGttManager(void)
elog(DEBUG1, "GTT cache initialized.");
}

/* Set the OID and name of the extension schema, all objects will be created in this schema */
pgtt_namespace_oid = get_extension_schema(extOid);
if (!OidIsValid(pgtt_namespace_oid))
elog(ERROR, "namespace %d can not be found.", pgtt_namespace_oid);
strcpy(pgtt_namespace_name, get_namespace_name(pgtt_namespace_oid));
/*
* Set the OID and name of the extension schema, all objects will be
* created in this schema.
*/
strcpy(pgtt_namespace_name, nspname);

return true;
}

/*
Expand All @@ -1408,18 +1431,23 @@ GttHashTableDeleteAll(void)
/* Restart the iteration in case that led to other drops */
hash_seq_term(&status);
hash_seq_init(&status, GttHashTable);
}
}
}

/*
* GetGttByName
* Returns a Gtt given a table name, or NULL if name is not found.
*
* Caller should have made sure that GTT has been properly loaded before
* calling this function.
*/
Gtt
GetGttByName(const char *name)
{
Gtt gtt;

Assert(GttHashTable != NULL);

if (PointerIsValid(name))
GttHashTableLookup(name, gtt);

Expand Down Expand Up @@ -1450,10 +1478,12 @@ gtt_load_global_temporary_tables(void)

/* Set and open the GTT definition storage relation */
rv = makeRangeVar(pgtt_namespace_name, CATALOG_GLOBAL_TEMP_REL, -1);

/* Open the CATALOG_GLOBAL_TEMP_REL table. We don't want to allow
* writable accesses by other session during import. */
PushActiveSnapshot(GetTransactionSnapshot());
snapshot = GetActiveSnapshot();
//snapshot = GetTransactionSnapshot();

#if (PG_VERSION_NUM >= 120000)
rel = table_openrv(rv, AccessShareLock);
scan = table_beginscan(rel, snapshot, 0, (ScanKey) NULL);
Expand Down Expand Up @@ -1489,6 +1519,8 @@ gtt_load_global_temporary_tables(void)
heap_endscan(scan);
heap_close(rel, AccessShareLock);
#endif

PopActiveSnapshot();
}

static Oid
Expand Down Expand Up @@ -1733,7 +1765,11 @@ gtt_post_parse_analyze(ParseState *pstate, Query *query, struct JumbleState * js
gtt_post_parse_analyze(ParseState *pstate, Query *query)
#endif
{
if (NOT_IN_PARALLEL_WORKER && pgtt_is_enabled && query->rtable != NIL)
/* Try to load pgtt if not already done. */
gtt_try_load();

if (NOT_IN_PARALLEL_WORKER && pgtt_is_enabled && query->rtable != NIL &&
GttHashTable != NULL)
{
/* replace the Oid of the template table by our new table in the rtable */
RangeTblEntry *rte = (RangeTblEntry *) linitial(query->rtable);
Expand Down Expand Up @@ -1864,16 +1900,25 @@ static void
force_pgtt_namespace (void)
{
#if PG_VERSION_NUM >= 170000
SearchPathMatcher *overridePath = GetSearchPathMatcher(CurrentMemoryContext);
SearchPathMatcher *overridePath;
#else
OverrideSearchPath *overridePath = GetOverrideSearchPath(CurrentMemoryContext);
OverrideSearchPath *overridePath;
#endif
ListCell *lc;
Oid schemaId = InvalidOid;
StringInfoData search_path;
bool found = false;
bool first = true;

if (!IsTransactionState() || GttHashTable == NULL)
return;

#if PG_VERSION_NUM >= 170000
overridePath = GetSearchPathMatcher(CurrentMemoryContext);
#else
overridePath = GetOverrideSearchPath(CurrentMemoryContext);
#endif

initStringInfo(&search_path);
/* verify that extension schema is in the path */
foreach(lc, overridePath->schemas)
Expand Down Expand Up @@ -1961,6 +2006,9 @@ gtt_create_table_as(Gtt gtt, bool skipdata)

elog(DEBUG1, "proceeding to Global Temporary Table creation.");

/* This can only be called if GTT has been properly loaded. */
Assert(GttHashTable != NULL);

connected = SPI_connect();
if (connected != SPI_OK_CONNECT)
ereport(ERROR, (errmsg("could not connect to SPI manager")));
Expand Down

0 comments on commit 5d8c89e

Please sign in to comment.