Skip to content

Commit

Permalink
allow bypassing preflight check (#68)
Browse files Browse the repository at this point in the history
  • Loading branch information
ChuckHend authored Dec 17, 2024
1 parent 4ea1991 commit 0e68cce
Show file tree
Hide file tree
Showing 8 changed files with 510 additions and 201 deletions.
659 changes: 468 additions & 191 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "pg_later"
version = "0.2.0"
version = "0.3.0"
edition = "2021"
publish = false

Expand Down
2 changes: 1 addition & 1 deletion Trunk.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ description = "Execute SQL statements now and check the results later."
homepage = "https://github.com/tembo-io/pg_later"
documentation = "https://github.com/tembo-io/pg_later"
categories = ["orchestration"]
version = "0.2.0"
version = "0.3.0"
loadable_libraries = [{ library_name = "pg_later", requires_restart = true }]

[build]
Expand Down
13 changes: 13 additions & 0 deletions sql/pg_later--0.2.0--0.3.0.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
DROP FUNCTION pglater."exec";

-- src/api.rs:26
-- pg_later::api::exec
CREATE FUNCTION "exec"(
"query" TEXT, /* &str */
"delay" bigint DEFAULT 0, /* i64 */
"validate" bool DEFAULT true /* bool */
) RETURNS bigint /* core::result::Result<i64, pgrx::spi::SpiError> */
STRICT
LANGUAGE c /* Rust */
AS 'MODULE_PATHNAME', 'exec_wrapper';
/* </end connected objects> */
14 changes: 10 additions & 4 deletions src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,17 @@ fn init() -> Result<bool, spi::Error> {

/// send a query to be executed by the next available worker
#[pg_extern]
pub fn exec(query: &str, delay: default!(i64, 0)) -> Result<i64, spi::Error> {
pub fn exec(
query: &str,
delay: default!(i64, 0),
validate: default!(bool, true),
) -> Result<i64, spi::Error> {
let prepared_query = query.replace('\'', "''").replace(';', "");
let dialect = PostgreSqlDialect {}; // Use PostgreSqlDialect for PostgreSQL
let parse_result = Parser::parse_sql(&dialect, &prepared_query);
parse_result.expect("Query parsing failed, please submit a valid query");
if validate {
let dialect = PostgreSqlDialect {};
let parse_result = Parser::parse_sql(&dialect, &prepared_query);
parse_result.expect("Query parsing failed, please submit a valid query");
}
let msg = serde_json::json!({
"query": prepared_query,
});
Expand Down
2 changes: 0 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use pgrx::prelude::*;

pgrx::pg_module_magic!();

mod api;
Expand Down
4 changes: 2 additions & 2 deletions src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,13 @@ pub fn get_pg_options() -> Result<PgConnectOptions> {
match (guc_host.as_ref(), env_socket.as_ref()) {
(Some(guc), _) => {
log!("pg-later: connecting with value from pglater.host");
let socket_conn = PostgresSocketConnection::from_unix_socket_string(&guc)
let socket_conn = PostgresSocketConnection::from_unix_socket_string(guc)
.expect("invalid value in pglater.host");
get_pgc_socket_opt(socket_conn)
}
(None, Some(env)) => {
log!("pg-later: connecting with value from env PGLATER_SOCKET_URL");
let socket_conn = PostgresSocketConnection::from_unix_socket_string(&env)
let socket_conn = PostgresSocketConnection::from_unix_socket_string(env)
.expect("invalid value in env PGLATER_SOCKET_URL");
get_pgc_socket_opt(socket_conn)
}
Expand Down
15 changes: 15 additions & 0 deletions tests/integration_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,4 +113,19 @@ async fn test_lifecycle() {
.await
.expect("failed to fetch");
assert!(row.0, "table must exist");

// bypass validation -- any statement will pass through to execution
let bypass_validation_query = format!(
"SELECT pglater.exec(
query => 'CREATE INDEX ON mytable USING hnsw (embedding vector_l2_ops)',
validate => false
)",
);
println!("pglater exec: {}", bypass_validation_query);
let q2 = sqlx::query(&bypass_validation_query)
.fetch_one(&conn)
.await
.expect("failed bypass check")
.get::<i64, usize>(0);
assert!(q2 > q1, "job ids should increase");
}

0 comments on commit 0e68cce

Please sign in to comment.