Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added connect_mut for data changing SPI operations #1913

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pgrx-examples/schemas/src/lib.rs
Original file line number Diff line number Diff line change
@@ -101,7 +101,7 @@ mod tests {

#[pg_test]
fn test_my_some_schema_type() -> Result<(), spi::Error> {
Spi::connect(|mut c| {
Spi::connect_mut(|c| {
// "MySomeSchemaType" is in 'some_schema', so it needs to be discoverable
c.update("SET search_path TO some_schema,public", None, &[])?;
assert_eq!(
4 changes: 2 additions & 2 deletions pgrx-tests/src/tests/bgworker_tests.rs
Original file line number Diff line number Diff line change
@@ -25,7 +25,7 @@ pub extern "C" fn bgworker(arg: pg_sys::Datum) {
if arg > 0 {
BackgroundWorker::transaction(|| {
Spi::run("CREATE TABLE tests.bgworker_test (v INTEGER);")?;
Spi::connect(|mut client| {
Spi::connect_mut(|client| {
client
.update("INSERT INTO tests.bgworker_test VALUES ($1);", None, &[arg.into()])
.map(|_| ())
@@ -66,7 +66,7 @@ pub extern "C" fn bgworker_return_value(arg: pg_sys::Datum) {
};
while BackgroundWorker::wait_latch(Some(Duration::from_millis(100))) {}
BackgroundWorker::transaction(|| {
Spi::connect(|mut c| {
Spi::connect_mut(|c| {
c.update("INSERT INTO tests.bgworker_test_return VALUES ($1)", None, &[val.into()])
.map(|_| ())
})
2 changes: 1 addition & 1 deletion pgrx-tests/src/tests/guc_tests.rs
Original file line number Diff line number Diff line change
@@ -202,7 +202,7 @@ mod tests {
Spi::run("SET test.no_show TO false;").expect("SPI failed");
Spi::run("SET test.no_reset_all TO false;").expect("SPI failed");
assert_eq!(GUC_NO_RESET_ALL.get(), false);
Spi::connect(|mut client| {
Spi::connect_mut(|client| {
let r = client.update("SHOW ALL", None, &[]).expect("SPI failed");

let mut no_reset_guc_in_show_all = false;
2 changes: 1 addition & 1 deletion pgrx-tests/src/tests/pg_cast_tests.rs
Original file line number Diff line number Diff line change
@@ -57,7 +57,7 @@ mod tests {

#[pg_test]
fn test_pg_cast_assignment_type_cast() {
let _ = Spi::connect(|mut client| {
let _ = Spi::connect_mut(|client| {
client.update("CREATE TABLE test_table(value int4);", None, &[])?;
client.update("INSERT INTO test_table VALUES('{\"a\": 1}'::json->'a');", None, &[])?;

23 changes: 11 additions & 12 deletions pgrx-tests/src/tests/spi_tests.rs
Original file line number Diff line number Diff line change
@@ -165,7 +165,7 @@ mod tests {

#[pg_test]
fn test_inserting_null() -> Result<(), pgrx::spi::Error> {
Spi::connect(|mut client| {
Spi::connect_mut(|client| {
client.update("CREATE TABLE tests.null_test (id uuid)", None, &[]).map(|_| ())
})?;
assert_eq!(
@@ -188,7 +188,7 @@ mod tests {

#[pg_test]
fn test_cursor() -> Result<(), spi::Error> {
Spi::connect(|mut client| {
Spi::connect_mut(|client| {
client.update("CREATE TABLE tests.cursor_table (id int)", None, &[])?;
client.update(
"INSERT INTO tests.cursor_table (id) \
@@ -208,7 +208,7 @@ mod tests {

#[pg_test]
fn test_cursor_prepared_statement() -> Result<(), pgrx::spi::Error> {
Spi::connect(|mut client| {
Spi::connect_mut(|client| {
client.update("CREATE TABLE tests.cursor_table (id int)", None, &[])?;
client.update(
"INSERT INTO tests.cursor_table (id) \
@@ -245,7 +245,7 @@ mod tests {
fn test_cursor_prepared_statement_panics_impl(
args: &[DatumWithOid],
) -> Result<(), pgrx::spi::Error> {
Spi::connect(|mut client| {
Spi::connect_mut(|client| {
client.update("CREATE TABLE tests.cursor_table (id int)", None, &[])?;
client.update(
"INSERT INTO tests.cursor_table (id) \
@@ -264,7 +264,7 @@ mod tests {

#[pg_test]
fn test_cursor_by_name() -> Result<(), pgrx::spi::Error> {
let cursor_name = Spi::connect(|mut client| {
let cursor_name = Spi::connect_mut(|client| {
client.update("CREATE TABLE tests.cursor_table (id int)", None, &[])?;
client.update(
"INSERT INTO tests.cursor_table (id) \
@@ -318,7 +318,7 @@ mod tests {
Ok::<_, spi::Error>(())
})?;

Spi::connect(|mut client| {
Spi::connect_mut(|client| {
let res = client.update("SET TIME ZONE 'PST8PDT'", None, &[])?;

assert_eq!(Err(spi::Error::NoTupleTable), res.columns());
@@ -334,9 +334,8 @@ mod tests {

#[pg_test]
fn test_spi_non_mut() -> Result<(), pgrx::spi::Error> {
// Ensures update and cursor APIs do not need mutable reference to SpiClient
Spi::connect(|mut client| {
client.update("SELECT 1", None, &[]).expect("SPI failed");
// Ensures cursor APIs do not need mutable reference to SpiClient
Spi::connect(|client| {
let cursor = client.open_cursor("SELECT 1", &[]).detach_into_name();
client.find_cursor(&cursor).map(|_| ())
})
@@ -428,7 +427,7 @@ mod tests {

#[pg_test]
fn test_readwrite_in_select_readwrite() -> Result<(), spi::Error> {
Spi::connect(|mut client| {
Spi::connect_mut(|client| {
// This is supposed to switch connection to read-write and run it there
client.update("CREATE TABLE a (id INT)", None, &[])?;
// This is supposed to run in read-write
@@ -459,7 +458,7 @@ mod tests {

#[pg_test]
fn test_spi_select_sees_update() -> spi::Result<()> {
let with_select = Spi::connect(|mut client| {
let with_select = Spi::connect_mut(|client| {
client.update("CREATE TABLE asd(id int)", None, &[])?;
client.update("INSERT INTO asd(id) VALUES (1)", None, &[])?;
client.select("SELECT COUNT(*) FROM asd", None, &[])?.first().get_one::<i64>()
@@ -485,7 +484,7 @@ mod tests {

#[pg_test]
fn test_spi_select_sees_update_in_other_session() -> spi::Result<()> {
Spi::connect::<spi::Result<()>, _>(|mut client| {
Spi::connect_mut::<spi::Result<()>, _>(|client| {
client.update("CREATE TABLE asd(id int)", None, &[])?;
client.update("INSERT INTO asd(id) VALUES (1)", None, &[])?;
Ok(())
4 changes: 2 additions & 2 deletions pgrx-tests/src/tests/srf_tests.rs
Original file line number Diff line number Diff line change
@@ -243,7 +243,7 @@ mod tests {

#[pg_test]
fn test_srf_setof_datum_detoasting_with_borrow() {
let cnt = Spi::connect(|mut client| {
let cnt = Spi::connect_mut(|client| {
// build up a table with one large column that Postgres will be forced to TOAST
client.update("CREATE TABLE test_srf_datum_detoasting AS SELECT array_to_string(array_agg(g),' ') s FROM (SELECT 'a' g FROM generate_series(1, 1000)) x;", None, &[])?;

@@ -261,7 +261,7 @@ mod tests {

#[pg_test]
fn test_srf_table_datum_detoasting_with_borrow() {
let cnt = Spi::connect(|mut client| {
let cnt = Spi::connect_mut(|client| {
// build up a table with one large column that Postgres will be forced to TOAST
client.update("CREATE TABLE test_srf_datum_detoasting AS SELECT array_to_string(array_agg(g),' ') s FROM (SELECT 'a' g FROM generate_series(1, 1000)) x;", None, &[])?;

2 changes: 1 addition & 1 deletion pgrx-tests/src/tests/struct_type_tests.rs
Original file line number Diff line number Diff line change
@@ -57,7 +57,7 @@ mod tests {

#[pg_test]
fn test_complex_storage_and_retrieval() -> Result<(), pgrx::spi::Error> {
let complex = Spi::connect(|mut client| {
let complex = Spi::connect_mut(|client| {
client.update(
"CREATE TABLE complex_test AS SELECT s as id, (s || '.0, 2.0' || s)::complex as value FROM generate_series(1, 1000) s;\
SELECT value FROM complex_test ORDER BY id;", None, &[])?.first().get_one::<PgBox<Complex>>()
Original file line number Diff line number Diff line change
@@ -4,7 +4,7 @@ error: lifetime may not live long enough
8 | let mut res = Spi::connect(|c| {
| -- return type of closure is SpiTupleTable<'2>
| |
| has type `SpiClient<'1>`
| has type `&SpiClient<'1>`
9 | / c.open_cursor("select 'hello world' from generate_series(1, 1000)", &[])
10 | | .fetch(1000)
11 | | .unwrap()
@@ -31,7 +31,7 @@ error: lifetime may not live long enough
| -- ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ returning this value requires that `'1` must outlive `'2`
| ||
| |return type of closure is SpiTupleTable<'2>
| has type `SpiClient<'1>`
| has type `&SpiClient<'1>`

error[E0515]: cannot return value referencing temporary value
--> tests/compile-fail/escaping-spiclient-1209-cursor.rs:16:26
Original file line number Diff line number Diff line change
@@ -5,4 +5,4 @@ error: lifetime may not live long enough
| -- ^^^^^^^^^^^^^^^^^ returning this value requires that `'1` must outlive `'2`
| ||
| |return type of closure is std::result::Result<pgrx::spi::PreparedStatement<'2>, pgrx::spi::SpiError>
| has type `SpiClient<'1>`
| has type `&SpiClient<'1>`
75 changes: 57 additions & 18 deletions pgrx/src/spi.rs
Original file line number Diff line number Diff line change
@@ -21,7 +21,6 @@ mod cursor;
mod query;
mod tuple;
pub use client::SpiClient;
use client::SpiConnection;
pub use cursor::SpiCursor;
pub use query::{OwnedPreparedStatement, PreparedStatement, Query};
pub use tuple::{SpiHeapTupleData, SpiHeapTupleDataEntry, SpiTupleTable};
@@ -237,13 +236,13 @@ impl Spi {
}

pub fn get_one<A: FromDatum + IntoDatum>(query: &str) -> Result<Option<A>> {
Spi::connect(|mut client| client.update(query, Some(1), &[])?.first().get_one())
Spi::connect_mut(|client| client.update(query, Some(1), &[])?.first().get_one())
}

pub fn get_two<A: FromDatum + IntoDatum, B: FromDatum + IntoDatum>(
query: &str,
) -> Result<(Option<A>, Option<B>)> {
Spi::connect(|mut client| client.update(query, Some(1), &[])?.first().get_two::<A, B>())
Spi::connect_mut(|client| client.update(query, Some(1), &[])?.first().get_two::<A, B>())
}

pub fn get_three<
@@ -253,7 +252,7 @@ impl Spi {
>(
query: &str,
) -> Result<(Option<A>, Option<B>, Option<C>)> {
Spi::connect(|mut client| {
Spi::connect_mut(|client| {
client.update(query, Some(1), &[])?.first().get_three::<A, B, C>()
})
}
@@ -262,14 +261,14 @@ impl Spi {
query: &str,
args: &[DatumWithOid<'mcx>],
) -> Result<Option<A>> {
Spi::connect(|mut client| client.update(query, Some(1), args)?.first().get_one())
Spi::connect_mut(|client| client.update(query, Some(1), args)?.first().get_one())
}

pub fn get_two_with_args<'mcx, A: FromDatum + IntoDatum, B: FromDatum + IntoDatum>(
query: &str,
args: &[DatumWithOid<'mcx>],
) -> Result<(Option<A>, Option<B>)> {
Spi::connect(|mut client| client.update(query, Some(1), args)?.first().get_two::<A, B>())
Spi::connect_mut(|client| client.update(query, Some(1), args)?.first().get_two::<A, B>())
}

pub fn get_three_with_args<
@@ -281,12 +280,12 @@ impl Spi {
query: &str,
args: &[DatumWithOid<'mcx>],
) -> Result<(Option<A>, Option<B>, Option<C>)> {
Spi::connect(|mut client| {
Spi::connect_mut(|client| {
client.update(query, Some(1), args)?.first().get_three::<A, B, C>()
})
}

/// just run an arbitrary SQL statement.
/// Just run an arbitrary SQL statement.
///
/// ## Safety
///
@@ -304,7 +303,7 @@ impl Spi {
query: &str,
args: &[DatumWithOid<'mcx>],
) -> std::result::Result<(), Error> {
Spi::connect(|mut client| client.update(query, None, args).map(|_| ()))
Spi::connect_mut(|client| client.update(query, None, args).map(|_| ()))
}

/// explain a query, returning its result in json form
@@ -314,7 +313,7 @@ impl Spi {

/// explain a query with args, returning its result in json form
pub fn explain_with_args<'mcx>(query: &str, args: &[DatumWithOid<'mcx>]) -> Result<Json> {
Ok(Spi::connect(|mut client| {
Ok(Spi::connect_mut(|client| {
client
.update(&format!("EXPLAIN (format json) {query}"), None, args)?
.first()
@@ -323,7 +322,7 @@ impl Spi {
.unwrap())
}

/// Execute SPI commands via the provided `SpiClient`.
/// Execute SPI read-only commands via the provided `SpiClient`.
///
/// While inside the provided closure, code executes under a short-lived "SPI Memory Context",
/// and Postgres will completely free that context when this function is finished.
@@ -360,10 +359,51 @@ impl Spi {
/// ([`pg_sys::SPI_connect()`]) **always** returns a successful response.
pub fn connect<R, F>(f: F) -> R
where
F: FnOnce(SpiClient<'_>) -> R, /* TODO: redesign this with 2 lifetimes:
- 'conn ~= CurrentMemoryContext after connection
- 'ret ~= SPI_palloc's context
*/
F: FnOnce(&SpiClient<'_>) -> R,
{
Self::connect_mut(|client| f(client))
}

/// Execute SPI mutating commands via the provided `SpiClient`.
///
/// While inside the provided closure, code executes under a short-lived "SPI Memory Context",
/// and Postgres will completely free that context when this function is finished.
///
/// pgrx' SPI API endeavors to return Datum values from functions like `::get_one()` that are
/// automatically copied into the into the `CurrentMemoryContext` at the time of this
/// function call.
///
/// # Examples
///
/// ```rust,no_run
/// use pgrx::prelude::*;
/// # fn foo() -> spi::Result<()> {
/// Spi::connect_mut(|client| {
/// client.update("INSERT INTO users VALUES ('Bob')", None, &[])?;
/// Ok(())
/// })
/// # }
/// ```
///
/// Note that `SpiClient` is scoped to the connection lifetime and cannot be returned. The
/// following code will not compile:
///
/// ```rust,compile_fail
/// use pgrx::prelude::*;
/// let cant_return_client = Spi::connect(|client| client);
/// ```
///
/// # Panics
///
/// This function will panic if for some reason it's unable to "connect" to Postgres' SPI
/// system. At the time of this writing, that's actually impossible as the underlying function
/// ([`pg_sys::SPI_connect()`]) **always** returns a successful response.
pub fn connect_mut<R, F>(f: F) -> R
where
F: FnOnce(&mut SpiClient<'_>) -> R, /* TODO: redesign this with 2 lifetimes:
- 'conn ~= CurrentMemoryContext after connection
- 'ret ~= SPI_palloc's context
*/
{
// connect to SPI
//
@@ -379,14 +419,13 @@ impl Spi {
// otherwise this function would need to return a `Result<R, spi::Error>` and that's a
// fucking nightmare for users to deal with. There's ample discussion around coming to
// this decision at https://github.com/pgcentralfoundation/pgrx/pull/977
let connection =
SpiConnection::connect().expect("SPI_connect indicated an unexpected failure");
let mut client = SpiClient::connect().expect("SPI_connect indicated an unexpected failure");

// run the provided closure within the memory context that SPI_connect()
// just put us un. We'll disconnect from SPI when the closure is finished.
// If there's a panic or elog(ERROR), we don't care about also disconnecting from
// SPI b/c Postgres will do that for us automatically
f(connection.client())
f(&mut client)
}

#[track_caller]
39 changes: 12 additions & 27 deletions pgrx/src/spi/client.rs
Original file line number Diff line number Diff line change
@@ -9,10 +9,18 @@ use super::query::PreparableQuery;

// TODO: should `'conn` be invariant?
pub struct SpiClient<'conn> {
__marker: PhantomData<&'conn SpiConnection>,
__marker: PhantomData<&'conn ()>,
}

impl<'conn> SpiClient<'conn> {
/// Connect to Postgres' SPI system
pub(super) fn connect() -> SpiResult<Self> {
// SPI_connect() is documented as being able to return SPI_ERROR_CONNECT, so we have to
// assume it could. The truth seems to be that it never actually does.
Spi::check_status(unsafe { pg_sys::SPI_connect() })?;
Ok(SpiClient { __marker: PhantomData })
}

/// Prepares a statement that is valid for the lifetime of the client
pub fn prepare<Q: PreparableQuery<'conn>>(
&self,
@@ -156,35 +164,12 @@ impl<'conn> SpiClient<'conn> {
}
}

/// a struct to manage our SPI connection lifetime
pub(super) struct SpiConnection(PhantomData<*mut ()>);

impl SpiConnection {
/// Connect to Postgres' SPI system
pub(super) fn connect() -> SpiResult<Self> {
// connect to SPI
//
// SPI_connect() is documented as being able to return SPI_ERROR_CONNECT, so we have to
// assume it could. The truth seems to be that it never actually does. The one user
// of SpiConnection::connect() returns `spi::Result` anyways, so it's no big deal
Spi::check_status(unsafe { pg_sys::SPI_connect() })?;
Ok(SpiConnection(PhantomData))
}
}

impl Drop for SpiConnection {
/// when SpiConnection is dropped, we make sure to disconnect from SPI
impl Drop for SpiClient<'_> {
/// When `SpiClient` is dropped, we make sure to disconnect from SPI
fn drop(&mut self) {
// best efforts to disconnect from SPI
// Best efforts to disconnect from SPI
// SPI_finish() would only complain if we hadn't previously called SPI_connect() and
// SpiConnection should prevent that from happening (assuming users don't go unsafe{})
Spi::check_status(unsafe { pg_sys::SPI_finish() }).ok();
}
}

impl SpiConnection {
/// Return a client that with a lifetime scoped to this connection.
pub(super) fn client(&self) -> SpiClient<'_> {
SpiClient { __marker: PhantomData }
}
}
6 changes: 3 additions & 3 deletions pgrx/src/spi/cursor.rs
Original file line number Diff line number Diff line change
@@ -32,7 +32,7 @@ type CursorName = String;
/// ```rust,no_run
/// use pgrx::prelude::*;
/// # fn foo() -> spi::Result<()> {
/// Spi::connect(|mut client| {
/// Spi::connect_mut(|client| {
/// let mut cursor = client.open_cursor("SELECT * FROM generate_series(1, 5)", &[]);
/// assert_eq!(Some(1), cursor.fetch(1)?.get_one::<i32>()?);
/// assert_eq!(Some(2), cursor.fetch(2)?.get_one::<i32>()?);
@@ -47,13 +47,13 @@ type CursorName = String;
/// ```rust,no_run
/// use pgrx::prelude::*;
/// # fn foo() -> spi::Result<()> {
/// let cursor_name = Spi::connect(|mut client| {
/// let cursor_name = Spi::connect_mut(|client| {
/// let mut cursor = client.open_cursor("SELECT * FROM generate_series(1, 5)", &[]);
/// assert_eq!(Ok(Some(1)), cursor.fetch(1)?.get_one::<i32>());
/// Ok::<_, spi::Error>(cursor.detach_into_name()) // <-- cursor gets dropped here
/// // <--- first SpiTupleTable gets freed by Spi::connect at this point
/// })?;
/// Spi::connect(|mut client| {
/// Spi::connect_mut(|client| {
/// let mut cursor = client.find_cursor(&cursor_name)?;
/// assert_eq!(Ok(Some(2)), cursor.fetch(1)?.get_one::<i32>());
/// drop(cursor); // <-- cursor gets dropped here