diff --git a/sqlx-core/src/any/connection/backend.rs b/sqlx-core/src/any/connection/backend.rs index 6c84c1d8ce..e59b345ed9 100644 --- a/sqlx-core/src/any/connection/backend.rs +++ b/sqlx-core/src/any/connection/backend.rs @@ -1,9 +1,9 @@ use crate::any::{Any, AnyArguments, AnyQueryResult, AnyRow, AnyStatement, AnyTypeInfo}; use crate::describe::Describe; +use crate::sql_str::SqlStr; use either::Either; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; -use std::borrow::Cow; use std::fmt::Debug; pub trait AnyConnectionBackend: std::any::Any + Debug + Send + 'static { @@ -33,7 +33,7 @@ pub trait AnyConnectionBackend: std::any::Any + Debug + Send + 'static { /// /// If we are already inside a transaction and `statement.is_some()`, then /// `Error::InvalidSavePoint` is returned without running any statements. - fn begin(&mut self, statement: Option>) -> BoxFuture<'_, crate::Result<()>>; + fn begin(&mut self, statement: Option) -> BoxFuture<'_, crate::Result<()>>; fn commit(&mut self) -> BoxFuture<'_, crate::Result<()>>; @@ -96,23 +96,23 @@ pub trait AnyConnectionBackend: std::any::Any + Debug + Send + 'static { fn fetch_many<'q>( &'q mut self, - query: &'q str, + query: SqlStr, persistent: bool, arguments: Option>, ) -> BoxStream<'q, crate::Result>>; fn fetch_optional<'q>( &'q mut self, - query: &'q str, + query: SqlStr, persistent: bool, arguments: Option>, ) -> BoxFuture<'q, crate::Result>>; fn prepare_with<'c, 'q: 'c>( &'c mut self, - sql: &'q str, + sql: SqlStr, parameters: &[AnyTypeInfo], - ) -> BoxFuture<'c, crate::Result>>; + ) -> BoxFuture<'c, crate::Result>; - fn describe<'q>(&'q mut self, sql: &'q str) -> BoxFuture<'q, crate::Result>>; + fn describe(&mut self, sql: SqlStr) -> BoxFuture<'_, crate::Result>>; } diff --git a/sqlx-core/src/any/connection/executor.rs b/sqlx-core/src/any/connection/executor.rs index ccf6dd7933..a7c2cd33a0 100644 --- a/sqlx-core/src/any/connection/executor.rs +++ b/sqlx-core/src/any/connection/executor.rs @@ -2,6 +2,7 @@ use crate::any::{Any, AnyConnection, AnyQueryResult, AnyRow, AnyStatement, AnyTy use crate::describe::Describe; use crate::error::Error; use crate::executor::{Execute, Executor}; +use crate::sql_str::SqlStr; use either::Either; use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; @@ -23,8 +24,8 @@ impl<'c> Executor<'c> for &'c mut AnyConnection { Ok(arguments) => arguments, Err(error) => return stream::once(future::ready(Err(error))).boxed(), }; - self.backend - .fetch_many(query.sql(), query.persistent(), arguments) + let persistent = query.persistent(); + self.backend.fetch_many(query.sql(), persistent, arguments) } fn fetch_optional<'e, 'q: 'e, E>( @@ -39,25 +40,23 @@ impl<'c> Executor<'c> for &'c mut AnyConnection { Ok(arguments) => arguments, Err(error) => return future::ready(Err(error)).boxed(), }; + let persistent = query.persistent(); self.backend - .fetch_optional(query.sql(), query.persistent(), arguments) + .fetch_optional(query.sql(), persistent, arguments) } - fn prepare_with<'e, 'q: 'e>( + fn prepare_with<'e>( self, - sql: &'q str, + sql: SqlStr, parameters: &[AnyTypeInfo], - ) -> BoxFuture<'e, Result, Error>> + ) -> BoxFuture<'e, Result> where 'c: 'e, { self.backend.prepare_with(sql, parameters) } - fn describe<'e, 'q: 'e>( - self, - sql: &'q str, - ) -> BoxFuture<'e, Result, Error>> + fn describe<'e>(self, sql: SqlStr) -> BoxFuture<'e, Result, Error>> where 'c: 'e, { diff --git a/sqlx-core/src/any/connection/mod.rs b/sqlx-core/src/any/connection/mod.rs index af4de060fc..c89dd108db 100644 --- a/sqlx-core/src/any/connection/mod.rs +++ b/sqlx-core/src/any/connection/mod.rs @@ -1,5 +1,4 @@ use futures_core::future::BoxFuture; -use std::borrow::Cow; use std::future::Future; use crate::any::{Any, AnyConnectOptions}; @@ -7,6 +6,7 @@ use crate::connection::{ConnectOptions, Connection}; use crate::error::Error; use crate::database::Database; +use crate::sql_str::SqlSafeStr; pub use backend::AnyConnectionBackend; use crate::transaction::Transaction; @@ -96,12 +96,12 @@ impl Connection for AnyConnection { fn begin_with( &mut self, - statement: impl Into>, + statement: impl SqlSafeStr, ) -> impl Future, Error>> + Send + '_ where Self: Sized, { - Transaction::begin(self, Some(statement.into())) + Transaction::begin(self, Some(statement.into_sql_str())) } fn cached_statements_size(&self) -> usize { diff --git a/sqlx-core/src/any/database.rs b/sqlx-core/src/any/database.rs index 9c3f15bb1f..6e8343e928 100644 --- a/sqlx-core/src/any/database.rs +++ b/sqlx-core/src/any/database.rs @@ -28,7 +28,7 @@ impl Database for Any { type Arguments<'q> = AnyArguments<'q>; type ArgumentBuffer<'q> = AnyArgumentBuffer<'q>; - type Statement<'q> = AnyStatement<'q>; + type Statement = AnyStatement; const NAME: &'static str = "Any"; diff --git a/sqlx-core/src/any/statement.rs b/sqlx-core/src/any/statement.rs index 6d513e9a06..98b28c2c32 100644 --- a/sqlx-core/src/any/statement.rs +++ b/sqlx-core/src/any/statement.rs @@ -3,15 +3,16 @@ use crate::column::ColumnIndex; use crate::database::Database; use crate::error::Error; use crate::ext::ustr::UStr; +use crate::sql_str::SqlStr; use crate::statement::Statement; use crate::HashMap; use either::Either; -use std::borrow::Cow; use std::sync::Arc; -pub struct AnyStatement<'q> { +#[derive(Clone)] +pub struct AnyStatement { #[doc(hidden)] - pub sql: Cow<'q, str>, + pub sql: SqlStr, #[doc(hidden)] pub parameters: Option, usize>>, #[doc(hidden)] @@ -20,19 +21,14 @@ pub struct AnyStatement<'q> { pub columns: Vec, } -impl<'q> Statement<'q> for AnyStatement<'q> { +impl Statement for AnyStatement { type Database = Any; - fn to_owned(&self) -> AnyStatement<'static> { - AnyStatement::<'static> { - sql: Cow::Owned(self.sql.clone().into_owned()), - column_names: self.column_names.clone(), - parameters: self.parameters.clone(), - columns: self.columns.clone(), - } + fn into_sql(self) -> SqlStr { + self.sql } - fn sql(&self) -> &str { + fn sql(&self) -> &SqlStr { &self.sql } @@ -51,8 +47,8 @@ impl<'q> Statement<'q> for AnyStatement<'q> { impl_statement_query!(AnyArguments<'_>); } -impl ColumnIndex> for &'_ str { - fn index(&self, statement: &AnyStatement<'_>) -> Result { +impl ColumnIndex for &'_ str { + fn index(&self, statement: &AnyStatement) -> Result { statement .column_names .get(*self) @@ -61,15 +57,14 @@ impl ColumnIndex> for &'_ str { } } -impl<'q> AnyStatement<'q> { +impl AnyStatement { #[doc(hidden)] pub fn try_from_statement( - query: &'q str, - statement: &S, + statement: S, column_names: Arc>, ) -> crate::Result where - S: Statement<'q>, + S: Statement, AnyTypeInfo: for<'a> TryFrom<&'a ::TypeInfo, Error = Error>, AnyColumn: for<'a> TryFrom<&'a ::Column, Error = Error>, { @@ -91,7 +86,7 @@ impl<'q> AnyStatement<'q> { .collect::, _>>()?; Ok(Self { - sql: query.into(), + sql: statement.into_sql(), columns, column_names, parameters, diff --git a/sqlx-core/src/any/transaction.rs b/sqlx-core/src/any/transaction.rs index a90c5e7c08..c803e97588 100644 --- a/sqlx-core/src/any/transaction.rs +++ b/sqlx-core/src/any/transaction.rs @@ -1,9 +1,9 @@ -use std::borrow::Cow; use std::future::Future; use crate::any::{Any, AnyConnection}; use crate::database::Database; use crate::error::Error; +use crate::sql_str::SqlStr; use crate::transaction::TransactionManager; pub struct AnyTransactionManager; @@ -11,10 +11,10 @@ pub struct AnyTransactionManager; impl TransactionManager for AnyTransactionManager { type Database = Any; - fn begin<'conn>( - conn: &'conn mut AnyConnection, - statement: Option>, - ) -> impl Future> + Send + 'conn { + fn begin( + conn: &mut AnyConnection, + statement: Option, + ) -> impl Future> + Send + '_ { conn.backend.begin(statement) } diff --git a/sqlx-core/src/column.rs b/sqlx-core/src/column.rs index fddc048c4b..132e7b0346 100644 --- a/sqlx-core/src/column.rs +++ b/sqlx-core/src/column.rs @@ -125,8 +125,8 @@ macro_rules! impl_column_index_for_row { #[macro_export] macro_rules! impl_column_index_for_statement { ($S:ident) => { - impl $crate::column::ColumnIndex<$S<'_>> for usize { - fn index(&self, statement: &$S<'_>) -> Result { + impl $crate::column::ColumnIndex<$S> for usize { + fn index(&self, statement: &$S) -> Result { let len = $crate::statement::Statement::columns(statement).len(); if *self >= len { diff --git a/sqlx-core/src/connection.rs b/sqlx-core/src/connection.rs index 25a211989c..fb698e91aa 100644 --- a/sqlx-core/src/connection.rs +++ b/sqlx-core/src/connection.rs @@ -1,10 +1,10 @@ use crate::database::{Database, HasStatementCache}; use crate::error::Error; +use crate::sql_str::SqlSafeStr; use crate::transaction::{Transaction, TransactionManager}; use futures_core::future::BoxFuture; use log::LevelFilter; -use std::borrow::Cow; use std::fmt::Debug; use std::future::Future; use std::str::FromStr; @@ -59,12 +59,12 @@ pub trait Connection: Send { /// `statement` does not put the connection into a transaction. fn begin_with( &mut self, - statement: impl Into>, + statement: impl SqlSafeStr, ) -> impl Future, Error>> + Send + '_ where Self: Sized, { - Transaction::begin(self, Some(statement.into())) + Transaction::begin(self, Some(statement.into_sql_str())) } /// Returns `true` if the connection is currently in a transaction. diff --git a/sqlx-core/src/database.rs b/sqlx-core/src/database.rs index e44c3d88ac..02d7a1214e 100644 --- a/sqlx-core/src/database.rs +++ b/sqlx-core/src/database.rs @@ -101,7 +101,7 @@ pub trait Database: 'static + Sized + Send + Debug { type ArgumentBuffer<'q>; /// The concrete `Statement` implementation for this database. - type Statement<'q>: Statement<'q, Database = Self>; + type Statement: Statement; /// The display name for this database driver. const NAME: &'static str; diff --git a/sqlx-core/src/executor.rs b/sqlx-core/src/executor.rs index f1a6ff4ba8..ab9737c9cd 100644 --- a/sqlx-core/src/executor.rs +++ b/sqlx-core/src/executor.rs @@ -1,6 +1,7 @@ use crate::database::Database; use crate::describe::Describe; use crate::error::{BoxDynError, Error}; +use crate::sql_str::{SqlSafeStr, SqlStr}; use either::Either; use futures_core::future::BoxFuture; @@ -148,10 +149,10 @@ pub trait Executor<'c>: Send + Debug + Sized { /// This explicit API is provided to allow access to the statement metadata available after /// it prepared but before the first row is returned. #[inline] - fn prepare<'e, 'q: 'e>( + fn prepare<'e>( self, - query: &'q str, - ) -> BoxFuture<'e, Result<::Statement<'q>, Error>> + query: SqlStr, + ) -> BoxFuture<'e, Result<::Statement, Error>> where 'c: 'e, { @@ -163,11 +164,11 @@ pub trait Executor<'c>: Send + Debug + Sized { /// /// Only some database drivers (PostgreSQL, MSSQL) can take advantage of /// this extra information to influence parameter type inference. - fn prepare_with<'e, 'q: 'e>( + fn prepare_with<'e>( self, - sql: &'q str, + sql: SqlStr, parameters: &'e [::TypeInfo], - ) -> BoxFuture<'e, Result<::Statement<'q>, Error>> + ) -> BoxFuture<'e, Result<::Statement, Error>> where 'c: 'e; @@ -177,10 +178,7 @@ pub trait Executor<'c>: Send + Debug + Sized { /// This is used by compile-time verification in the query macros to /// power their type inference. #[doc(hidden)] - fn describe<'e, 'q: 'e>( - self, - sql: &'q str, - ) -> BoxFuture<'e, Result, Error>> + fn describe<'e>(self, sql: SqlStr) -> BoxFuture<'e, Result, Error>> where 'c: 'e; } @@ -194,10 +192,10 @@ pub trait Executor<'c>: Send + Debug + Sized { /// pub trait Execute<'q, DB: Database>: Send + Sized { /// Gets the SQL that will be executed. - fn sql(&self) -> &'q str; + fn sql(self) -> SqlStr; /// Gets the previously cached statement, if available. - fn statement(&self) -> Option<&DB::Statement<'q>>; + fn statement(&self) -> Option<&DB::Statement>; /// Returns the arguments to be bound against the query string. /// @@ -212,16 +210,17 @@ pub trait Execute<'q, DB: Database>: Send + Sized { fn persistent(&self) -> bool; } -// NOTE: `Execute` is explicitly not implemented for String and &String to make it slightly more -// involved to write `conn.execute(format!("SELECT {val}"))` -impl<'q, DB: Database> Execute<'q, DB> for &'q str { +impl<'q, DB: Database, T> Execute<'q, DB> for T +where + T: SqlSafeStr + Send, +{ #[inline] - fn sql(&self) -> &'q str { - self + fn sql(self) -> SqlStr { + self.into_sql_str() } #[inline] - fn statement(&self) -> Option<&DB::Statement<'q>> { + fn statement(&self) -> Option<&DB::Statement> { None } @@ -236,14 +235,17 @@ impl<'q, DB: Database> Execute<'q, DB> for &'q str { } } -impl<'q, DB: Database> Execute<'q, DB> for (&'q str, Option<::Arguments<'q>>) { +impl<'q, DB: Database, T> Execute<'q, DB> for (T, Option<::Arguments<'q>>) +where + T: SqlSafeStr + Send, +{ #[inline] - fn sql(&self) -> &'q str { - self.0 + fn sql(self) -> SqlStr { + self.0.into_sql_str() } #[inline] - fn statement(&self) -> Option<&DB::Statement<'q>> { + fn statement(&self) -> Option<&DB::Statement> { None } diff --git a/sqlx-core/src/lib.rs b/sqlx-core/src/lib.rs index 09f2900ba8..494c41e9bf 100644 --- a/sqlx-core/src/lib.rs +++ b/sqlx-core/src/lib.rs @@ -72,6 +72,7 @@ pub mod net; pub mod query_as; pub mod query_builder; pub mod query_scalar; +pub mod sql_str; pub mod raw_sql; pub mod row; diff --git a/sqlx-core/src/logger.rs b/sqlx-core/src/logger.rs index 18d5843d38..6114854396 100644 --- a/sqlx-core/src/logger.rs +++ b/sqlx-core/src/logger.rs @@ -1,4 +1,4 @@ -use crate::connection::LogSettings; +use crate::{connection::LogSettings, sql_str::SqlStr}; use std::time::Instant; // Yes these look silly. `tracing` doesn't currently support dynamic levels @@ -60,16 +60,16 @@ pub(crate) fn private_level_filter_to_trace_level( private_level_filter_to_levels(filter).map(|(level, _)| level) } -pub struct QueryLogger<'q> { - sql: &'q str, +pub struct QueryLogger { + sql: SqlStr, rows_returned: u64, rows_affected: u64, start: Instant, settings: LogSettings, } -impl<'q> QueryLogger<'q> { - pub fn new(sql: &'q str, settings: LogSettings) -> Self { +impl QueryLogger { + pub fn new(sql: SqlStr, settings: LogSettings) -> Self { Self { sql, rows_returned: 0, @@ -104,19 +104,11 @@ impl<'q> QueryLogger<'q> { let log_is_enabled = log::log_enabled!(target: "sqlx::query", log_level) || private_tracing_dynamic_enabled!(target: "sqlx::query", tracing_level); if log_is_enabled { - let mut summary = parse_query_summary(self.sql); + let mut summary = parse_query_summary(self.sql.as_str()); - let sql = if summary != self.sql { + let sql = if summary != self.sql.as_str() { summary.push_str(" …"); - format!( - "\n\n{}\n", - self.sql /* - sqlformat::format( - self.sql, - &sqlformat::QueryParams::None, - sqlformat::FormatOptions::default() - )*/ - ) + format!("\n\n{}\n", self.sql.as_str()) } else { String::new() }; @@ -158,7 +150,7 @@ impl<'q> QueryLogger<'q> { } } -impl Drop for QueryLogger<'_> { +impl Drop for QueryLogger { fn drop(&mut self) { self.finish(); } diff --git a/sqlx-core/src/migrate/migration.rs b/sqlx-core/src/migrate/migration.rs index 1f1175ce58..79721d244d 100644 --- a/sqlx-core/src/migrate/migration.rs +++ b/sqlx-core/src/migrate/migration.rs @@ -1,6 +1,8 @@ use sha2::{Digest, Sha384}; use std::borrow::Cow; +use crate::sql_str::SqlStr; + use super::MigrationType; #[derive(Debug, Clone)] @@ -8,7 +10,7 @@ pub struct Migration { pub version: i64, pub description: Cow<'static, str>, pub migration_type: MigrationType, - pub sql: Cow<'static, str>, + pub sql: SqlStr, pub checksum: Cow<'static, [u8]>, pub no_tx: bool, } @@ -18,10 +20,10 @@ impl Migration { version: i64, description: Cow<'static, str>, migration_type: MigrationType, - sql: Cow<'static, str>, + sql: SqlStr, no_tx: bool, ) -> Self { - let checksum = checksum(&sql); + let checksum = checksum(sql.as_str()); Self::with_checksum( version, @@ -37,7 +39,7 @@ impl Migration { version: i64, description: Cow<'static, str>, migration_type: MigrationType, - sql: Cow<'static, str>, + sql: SqlStr, checksum: Cow<'static, [u8]>, no_tx: bool, ) -> Self { diff --git a/sqlx-core/src/migrate/source.rs b/sqlx-core/src/migrate/source.rs index 9c2ef7719b..4648e53f1e 100644 --- a/sqlx-core/src/migrate/source.rs +++ b/sqlx-core/src/migrate/source.rs @@ -1,5 +1,6 @@ use crate::error::BoxDynError; use crate::migrate::{migration, Migration, MigrationType}; +use crate::sql_str::{AssertSqlSafe, SqlSafeStr}; use futures_core::future::BoxFuture; use std::borrow::Cow; @@ -239,7 +240,7 @@ pub fn resolve_blocking_with_config( version, Cow::Owned(description), migration_type, - Cow::Owned(sql), + AssertSqlSafe(sql).into_sql_str(), checksum.into(), no_tx, ), diff --git a/sqlx-core/src/pool/executor.rs b/sqlx-core/src/pool/executor.rs index 0eda818a5b..c168a70f51 100644 --- a/sqlx-core/src/pool/executor.rs +++ b/sqlx-core/src/pool/executor.rs @@ -8,8 +8,9 @@ use crate::describe::Describe; use crate::error::Error; use crate::executor::{Execute, Executor}; use crate::pool::Pool; +use crate::sql_str::SqlStr; -impl Executor<'_> for &'_ Pool +impl<'p, DB: Database> Executor<'p> for &'_ Pool where for<'c> &'c mut DB::Connection: Executor<'c, Database = DB>, { @@ -48,21 +49,21 @@ where Box::pin(async move { pool.acquire().await?.fetch_optional(query).await }) } - fn prepare_with<'e, 'q: 'e>( + fn prepare_with<'e>( self, - sql: &'q str, + sql: SqlStr, parameters: &'e [::TypeInfo], - ) -> BoxFuture<'e, Result<::Statement<'q>, Error>> { + ) -> BoxFuture<'e, Result<::Statement, Error>> + where + 'p: 'e, + { let pool = self.clone(); Box::pin(async move { pool.acquire().await?.prepare_with(sql, parameters).await }) } #[doc(hidden)] - fn describe<'e, 'q: 'e>( - self, - sql: &'q str, - ) -> BoxFuture<'e, Result, Error>> { + fn describe<'e>(self, sql: SqlStr) -> BoxFuture<'e, Result, Error>> { let pool = self.clone(); Box::pin(async move { pool.acquire().await?.describe(sql).await }) diff --git a/sqlx-core/src/pool/mod.rs b/sqlx-core/src/pool/mod.rs index 07d592c824..f11ff1d76a 100644 --- a/sqlx-core/src/pool/mod.rs +++ b/sqlx-core/src/pool/mod.rs @@ -54,7 +54,6 @@ //! [`Pool::acquire`] or //! [`Pool::begin`]. -use std::borrow::Cow; use std::fmt; use std::future::Future; use std::pin::{pin, Pin}; @@ -69,6 +68,7 @@ use futures_util::FutureExt; use crate::connection::Connection; use crate::database::Database; use crate::error::Error; +use crate::sql_str::SqlSafeStr; use crate::transaction::Transaction; pub use self::connection::PoolConnection; @@ -390,11 +390,11 @@ impl Pool { /// Retrieves a connection and immediately begins a new transaction using `statement`. pub async fn begin_with( &self, - statement: impl Into>, + statement: impl SqlSafeStr, ) -> Result, Error> { Transaction::begin( MaybePoolConnection::PoolConnection(self.acquire().await?), - Some(statement.into()), + Some(statement.into_sql_str()), ) .await } @@ -403,12 +403,12 @@ impl Pool { /// transaction using `statement`. pub async fn try_begin_with( &self, - statement: impl Into>, + statement: impl SqlSafeStr, ) -> Result>, Error> { match self.try_acquire() { Some(conn) => Transaction::begin( MaybePoolConnection::PoolConnection(conn), - Some(statement.into()), + Some(statement.into_sql_str()), ) .await .map(Some), diff --git a/sqlx-core/src/query.rs b/sqlx-core/src/query.rs index 0e45643f2e..97c166116a 100644 --- a/sqlx-core/src/query.rs +++ b/sqlx-core/src/query.rs @@ -9,13 +9,14 @@ use crate::database::{Database, HasStatementCache}; use crate::encode::Encode; use crate::error::{BoxDynError, Error}; use crate::executor::{Execute, Executor}; +use crate::sql_str::{SqlSafeStr, SqlStr}; use crate::statement::Statement; use crate::types::Type; /// A single SQL query as a prepared statement. Returned by [`query()`]. #[must_use = "query must be executed to affect database"] pub struct Query<'q, DB: Database, A> { - pub(crate) statement: Either<&'q str, &'q DB::Statement<'q>>, + pub(crate) statement: Either, pub(crate) arguments: Option>, pub(crate) database: PhantomData, pub(crate) persistent: bool, @@ -44,14 +45,14 @@ where A: Send + IntoArguments<'q, DB>, { #[inline] - fn sql(&self) -> &'q str { + fn sql(self) -> SqlStr { match self.statement { - Either::Right(statement) => statement.sql(), + Either::Right(statement) => statement.sql().clone(), Either::Left(sql) => sql, } } - fn statement(&self) -> Option<&DB::Statement<'q>> { + fn statement(&self) -> Option<&DB::Statement> { match self.statement { Either::Right(statement) => Some(statement), Either::Left(_) => None, @@ -303,12 +304,12 @@ where A: IntoArguments<'q, DB>, { #[inline] - fn sql(&self) -> &'q str { + fn sql(self) -> SqlStr { self.inner.sql() } #[inline] - fn statement(&self) -> Option<&DB::Statement<'q>> { + fn statement(&self) -> Option<&DB::Statement> { self.inner.statement() } @@ -499,9 +500,9 @@ where } /// Execute a single SQL query as a prepared statement (explicitly created). -pub fn query_statement<'q, DB>( - statement: &'q DB::Statement<'q>, -) -> Query<'q, DB, ::Arguments<'q>> +pub fn query_statement( + statement: &DB::Statement, +) -> Query<'_, DB, ::Arguments<'_>> where DB: Database, { @@ -515,7 +516,7 @@ where /// Execute a single SQL query as a prepared statement (explicitly created), with the given arguments. pub fn query_statement_with<'q, DB, A>( - statement: &'q DB::Statement<'q>, + statement: &'q DB::Statement, arguments: A, ) -> Query<'q, DB, A> where @@ -559,7 +560,7 @@ where /// let query = format!("SELECT * FROM articles WHERE content LIKE '%{user_input}%'"); /// // where `conn` is `PgConnection` or `MySqlConnection` /// // or some other type that implements `Executor`. -/// let results = sqlx::query(&query).fetch_all(&mut conn).await?; +/// let results = sqlx::query(sqlx::AssertSqlSafe(query)).fetch_all(&mut conn).await?; /// # Ok(()) /// # } /// ``` @@ -654,14 +655,14 @@ where /// /// As an additional benefit, query parameters are usually sent in a compact binary encoding instead of a human-readable /// text encoding, which saves bandwidth. -pub fn query(sql: &str) -> Query<'_, DB, ::Arguments<'_>> +pub fn query<'a, DB>(sql: impl SqlSafeStr) -> Query<'a, DB, ::Arguments<'a>> where DB: Database, { Query { database: PhantomData, arguments: Some(Ok(Default::default())), - statement: Either::Left(sql), + statement: Either::Left(sql.into_sql_str()), persistent: true, } } @@ -669,7 +670,7 @@ where /// Execute a SQL query as a prepared statement (transparently cached), with the given arguments. /// /// See [`query()`][query] for details, such as supported syntax. -pub fn query_with<'q, DB, A>(sql: &'q str, arguments: A) -> Query<'q, DB, A> +pub fn query_with<'q, DB, A>(sql: impl SqlSafeStr, arguments: A) -> Query<'q, DB, A> where DB: Database, A: IntoArguments<'q, DB>, @@ -679,7 +680,7 @@ where /// Same as [`query_with`] but is initialized with a Result of arguments instead pub fn query_with_result<'q, DB, A>( - sql: &'q str, + sql: impl SqlSafeStr, arguments: Result, ) -> Query<'q, DB, A> where @@ -689,7 +690,7 @@ where Query { database: PhantomData, arguments: Some(arguments), - statement: Either::Left(sql), + statement: Either::Left(sql.into_sql_str()), persistent: true, } } diff --git a/sqlx-core/src/query_as.rs b/sqlx-core/src/query_as.rs index 7cefb02975..e58a3f0f8b 100644 --- a/sqlx-core/src/query_as.rs +++ b/sqlx-core/src/query_as.rs @@ -11,6 +11,7 @@ use crate::error::{BoxDynError, Error}; use crate::executor::{Execute, Executor}; use crate::from_row::FromRow; use crate::query::{query, query_statement, query_statement_with, query_with_result, Query}; +use crate::sql_str::{SqlSafeStr, SqlStr}; use crate::types::Type; /// A single SQL query as a prepared statement, mapping results using [`FromRow`]. @@ -27,12 +28,12 @@ where A: 'q + IntoArguments<'q, DB>, { #[inline] - fn sql(&self) -> &'q str { + fn sql(self) -> SqlStr { self.inner.sql() } #[inline] - fn statement(&self) -> Option<&DB::Statement<'q>> { + fn statement(&self) -> Option<&DB::Statement> { self.inner.statement() } @@ -337,7 +338,9 @@ where /// /// ``` #[inline] -pub fn query_as<'q, DB, O>(sql: &'q str) -> QueryAs<'q, DB, O, ::Arguments<'q>> +pub fn query_as<'q, DB, O>( + sql: impl SqlSafeStr, +) -> QueryAs<'q, DB, O, ::Arguments<'q>> where DB: Database, O: for<'r> FromRow<'r, DB::Row>, @@ -355,7 +358,7 @@ where /// /// For details about type mapping from [`FromRow`], see [`query_as()`]. #[inline] -pub fn query_as_with<'q, DB, O, A>(sql: &'q str, arguments: A) -> QueryAs<'q, DB, O, A> +pub fn query_as_with<'q, DB, O, A>(sql: impl SqlSafeStr, arguments: A) -> QueryAs<'q, DB, O, A> where DB: Database, A: IntoArguments<'q, DB>, @@ -367,7 +370,7 @@ where /// Same as [`query_as_with`] but takes arguments as a Result #[inline] pub fn query_as_with_result<'q, DB, O, A>( - sql: &'q str, + sql: impl SqlSafeStr, arguments: Result, ) -> QueryAs<'q, DB, O, A> where @@ -382,9 +385,9 @@ where } // Make a SQL query from a statement, that is mapped to a concrete type. -pub fn query_statement_as<'q, DB, O>( - statement: &'q DB::Statement<'q>, -) -> QueryAs<'q, DB, O, ::Arguments<'q>> +pub fn query_statement_as( + statement: &DB::Statement, +) -> QueryAs<'_, DB, O, ::Arguments<'_>> where DB: Database, O: for<'r> FromRow<'r, DB::Row>, @@ -397,7 +400,7 @@ where // Make a SQL query from a statement, with the given arguments, that is mapped to a concrete type. pub fn query_statement_as_with<'q, DB, O, A>( - statement: &'q DB::Statement<'q>, + statement: &'q DB::Statement, arguments: A, ) -> QueryAs<'q, DB, O, A> where diff --git a/sqlx-core/src/query_builder.rs b/sqlx-core/src/query_builder.rs index b14d19adb2..7dff67831c 100644 --- a/sqlx-core/src/query_builder.rs +++ b/sqlx-core/src/query_builder.rs @@ -3,6 +3,7 @@ use std::fmt::Display; use std::fmt::Write; use std::marker::PhantomData; +use std::sync::Arc; use crate::arguments::{Arguments, IntoArguments}; use crate::database::Database; @@ -11,6 +12,9 @@ use crate::from_row::FromRow; use crate::query::Query; use crate::query_as::QueryAs; use crate::query_scalar::QueryScalar; +use crate::sql_str::AssertSqlSafe; +use crate::sql_str::SqlSafeStr; +use crate::sql_str::SqlStr; use crate::types::Type; use crate::Either; @@ -25,7 +29,7 @@ pub struct QueryBuilder<'args, DB> where DB: Database, { - query: String, + query: Arc, init_len: usize, arguments: Option<::Arguments<'args>>, } @@ -34,12 +38,14 @@ impl Default for QueryBuilder<'_, DB> { fn default() -> Self { QueryBuilder { init_len: 0, - query: String::default(), + query: Default::default(), arguments: Some(Default::default()), } } } +const ERROR: &str = "BUG: query must not be shared at this point in time"; + impl<'args, DB: Database> QueryBuilder<'args, DB> where DB: Database, @@ -55,7 +61,7 @@ where QueryBuilder { init_len: init.len(), - query: init, + query: init.into(), arguments: Some(Default::default()), } } @@ -73,7 +79,7 @@ where QueryBuilder { init_len: init.len(), - query: init, + query: init.into(), arguments: Some(arguments.into_arguments()), } } @@ -115,8 +121,9 @@ where /// e.g. check that strings aren't too long, numbers are within expected ranges, etc. pub fn push(&mut self, sql: impl Display) -> &mut Self { self.sanity_check(); + let query: &mut String = Arc::get_mut(&mut self.query).expect(ERROR); - write!(self.query, "{sql}").expect("error formatting `sql`"); + write!(query, "{sql}").expect("error formatting `sql`"); self } @@ -157,8 +164,9 @@ where .expect("BUG: Arguments taken already"); arguments.add(value).expect("Failed to add argument"); + let query: &mut String = Arc::get_mut(&mut self.query).expect(ERROR); arguments - .format_placeholder(&mut self.query) + .format_placeholder(query) .expect("error in format_placeholder"); self @@ -453,7 +461,7 @@ where self.sanity_check(); Query { - statement: Either::Left(&self.query), + statement: Either::Left(self.sql()), arguments: self.arguments.take().map(Ok), database: PhantomData, persistent: true, @@ -510,20 +518,28 @@ where /// The query is truncated to the initial fragment provided to [`new()`][Self::new] and /// the bind arguments are reset. pub fn reset(&mut self) -> &mut Self { - self.query.truncate(self.init_len); + // Someone can hold onto a clone of `self.query`, to avoid panicking here we should just + // allocate a new `String`. + let query: &mut String = Arc::make_mut(&mut self.query); + query.truncate(self.init_len); self.arguments = Some(Default::default()); self } /// Get the current build SQL; **note**: may not be syntactically correct. - pub fn sql(&self) -> &str { - &self.query + pub fn sql(&self) -> SqlStr { + AssertSqlSafe(self.query.clone()).into_sql_str() + } + + /// Deconstruct this `QueryBuilder`, returning the built SQL. May not be syntactically correct. + pub fn into_string(self) -> String { + Arc::unwrap_or_clone(self.query) } /// Deconstruct this `QueryBuilder`, returning the built SQL. May not be syntactically correct. - pub fn into_sql(self) -> String { - self.query + pub fn into_sql(self) -> SqlStr { + AssertSqlSafe(self.query).into_sql_str() } } diff --git a/sqlx-core/src/query_scalar.rs b/sqlx-core/src/query_scalar.rs index f3fcfb403a..1059463874 100644 --- a/sqlx-core/src/query_scalar.rs +++ b/sqlx-core/src/query_scalar.rs @@ -11,6 +11,7 @@ use crate::from_row::FromRow; use crate::query_as::{ query_as, query_as_with_result, query_statement_as, query_statement_as_with, QueryAs, }; +use crate::sql_str::{SqlSafeStr, SqlStr}; use crate::types::Type; /// A single SQL query as a prepared statement which extracts only the first column of each row. @@ -25,11 +26,11 @@ where A: 'q + IntoArguments<'q, DB>, { #[inline] - fn sql(&self) -> &'q str { + fn sql(self) -> SqlStr { self.inner.sql() } - fn statement(&self) -> Option<&DB::Statement<'q>> { + fn statement(&self) -> Option<&DB::Statement> { self.inner.statement() } @@ -319,7 +320,7 @@ where /// ``` #[inline] pub fn query_scalar<'q, DB, O>( - sql: &'q str, + sql: impl SqlSafeStr, ) -> QueryScalar<'q, DB, O, ::Arguments<'q>> where DB: Database, @@ -337,7 +338,10 @@ where /// /// For details about prepared statements and allowed SQL syntax, see [`query()`][crate::query::query]. #[inline] -pub fn query_scalar_with<'q, DB, O, A>(sql: &'q str, arguments: A) -> QueryScalar<'q, DB, O, A> +pub fn query_scalar_with<'q, DB, O, A>( + sql: impl SqlSafeStr, + arguments: A, +) -> QueryScalar<'q, DB, O, A> where DB: Database, A: IntoArguments<'q, DB>, @@ -349,7 +353,7 @@ where /// Same as [`query_scalar_with`] but takes arguments as Result #[inline] pub fn query_scalar_with_result<'q, DB, O, A>( - sql: &'q str, + sql: impl SqlSafeStr, arguments: Result, ) -> QueryScalar<'q, DB, O, A> where @@ -363,9 +367,9 @@ where } // Make a SQL query from a statement, that is mapped to a concrete value. -pub fn query_statement_scalar<'q, DB, O>( - statement: &'q DB::Statement<'q>, -) -> QueryScalar<'q, DB, O, ::Arguments<'q>> +pub fn query_statement_scalar( + statement: &DB::Statement, +) -> QueryScalar<'_, DB, O, ::Arguments<'_>> where DB: Database, (O,): for<'r> FromRow<'r, DB::Row>, @@ -377,7 +381,7 @@ where // Make a SQL query from a statement, with the given arguments, that is mapped to a concrete value. pub fn query_statement_scalar_with<'q, DB, O, A>( - statement: &'q DB::Statement<'q>, + statement: &'q DB::Statement, arguments: A, ) -> QueryScalar<'q, DB, O, A> where diff --git a/sqlx-core/src/raw_sql.rs b/sqlx-core/src/raw_sql.rs index f4104348bc..e465d08108 100644 --- a/sqlx-core/src/raw_sql.rs +++ b/sqlx-core/src/raw_sql.rs @@ -5,6 +5,7 @@ use futures_core::stream::BoxStream; use crate::database::Database; use crate::error::BoxDynError; use crate::executor::{Execute, Executor}; +use crate::sql_str::{SqlSafeStr, SqlStr}; use crate::Error; // AUTHOR'S NOTE: I was just going to call this API `sql()` and `Sql`, respectively, @@ -16,7 +17,7 @@ use crate::Error; /// One or more raw SQL statements, separated by semicolons (`;`). /// /// See [`raw_sql()`] for details. -pub struct RawSql<'q>(&'q str); +pub struct RawSql(SqlStr); /// Execute one or more statements as raw SQL, separated by semicolons (`;`). /// @@ -115,16 +116,16 @@ pub struct RawSql<'q>(&'q str); /// /// See [MySQL manual, section 13.3.3: Statements That Cause an Implicit Commit](https://dev.mysql.com/doc/refman/8.0/en/implicit-commit.html) for details. /// See also: [MariaDB manual: SQL statements That Cause an Implicit Commit](https://mariadb.com/kb/en/sql-statements-that-cause-an-implicit-commit/). -pub fn raw_sql(sql: &str) -> RawSql<'_> { - RawSql(sql) +pub fn raw_sql(sql: impl SqlSafeStr) -> RawSql { + RawSql(sql.into_sql_str()) } -impl<'q, DB: Database> Execute<'q, DB> for RawSql<'q> { - fn sql(&self) -> &'q str { +impl<'q, DB: Database> Execute<'q, DB> for RawSql { + fn sql(self) -> SqlStr { self.0 } - fn statement(&self) -> Option<&::Statement<'q>> { + fn statement(&self) -> Option<&::Statement> { None } @@ -137,12 +138,11 @@ impl<'q, DB: Database> Execute<'q, DB> for RawSql<'q> { } } -impl<'q> RawSql<'q> { +impl RawSql { /// Execute the SQL string and return the total number of rows affected. #[inline] pub async fn execute<'e, E, DB>(self, executor: E) -> crate::Result where - 'q: 'e, DB: Database, E: Executor<'e, Database = DB>, { @@ -156,7 +156,6 @@ impl<'q> RawSql<'q> { executor: E, ) -> BoxStream<'e, crate::Result> where - 'q: 'e, DB: Database, E: Executor<'e, Database = DB>, { @@ -169,7 +168,6 @@ impl<'q> RawSql<'q> { #[inline] pub fn fetch<'e, E, DB>(self, executor: E) -> BoxStream<'e, Result> where - 'q: 'e, DB: Database, E: Executor<'e, Database = DB>, { @@ -186,7 +184,6 @@ impl<'q> RawSql<'q> { executor: E, ) -> BoxStream<'e, Result, Error>> where - 'q: 'e, DB: Database, E: Executor<'e, Database = DB>, { @@ -203,7 +200,6 @@ impl<'q> RawSql<'q> { #[inline] pub fn fetch_all<'e, E, DB>(self, executor: E) -> BoxFuture<'e, crate::Result>> where - 'q: 'e, DB: Database, E: Executor<'e, Database = DB>, { @@ -225,7 +221,6 @@ impl<'q> RawSql<'q> { #[inline] pub fn fetch_one<'e, E, DB>(self, executor: E) -> BoxFuture<'e, crate::Result> where - 'q: 'e, DB: Database, E: Executor<'e, Database = DB>, { @@ -247,7 +242,6 @@ impl<'q> RawSql<'q> { #[inline] pub async fn fetch_optional<'e, E, DB>(self, executor: E) -> crate::Result where - 'q: 'e, DB: Database, E: Executor<'e, Database = DB>, { diff --git a/sqlx-core/src/rt/mod.rs b/sqlx-core/src/rt/mod.rs index 2a21daff96..696d51d43d 100644 --- a/sqlx-core/src/rt/mod.rs +++ b/sqlx-core/src/rt/mod.rs @@ -29,6 +29,7 @@ pub async fn timeout(duration: Duration, f: F) -> Result SqlStr; +} + +impl SqlSafeStr for &'static str { + #[inline] + fn into_sql_str(self) -> SqlStr { + SqlStr(Repr::Static(self)) + } +} + +/// Assert that a query string is safe to execute on a database connection. +/// +/// Using this API means that **you** have made sure that the string contents do not contain a +/// [SQL injection vulnerability][injection]. It means that, if the string was constructed +/// dynamically, and/or from user input, you have taken care to sanitize the input yourself. +/// SQLx does not provide any sort of sanitization; the design of SQLx prefers the use +/// of prepared statements for dynamic input. +/// +/// The maintainers of SQLx take no responsibility for any data leaks or loss resulting from misuse +/// of this API. **Use at your own risk.** +/// +/// Note that `&'static str` implements [`SqlSafeStr`] directly and so does not need to be wrapped +/// with this type. +/// +/// [injection]: https://en.wikipedia.org/wiki/SQL_injection +pub struct AssertSqlSafe(pub T); + +/// Note: copies the string. +/// +/// It is recommended to pass one of the supported owned string types instead. +impl SqlSafeStr for AssertSqlSafe<&str> { + #[inline] + fn into_sql_str(self) -> SqlStr { + SqlStr(Repr::Arced(self.0.into())) + } +} +impl SqlSafeStr for AssertSqlSafe { + #[inline] + fn into_sql_str(self) -> SqlStr { + SqlStr(Repr::Owned(self.0)) + } +} + +impl SqlSafeStr for AssertSqlSafe> { + #[inline] + fn into_sql_str(self) -> SqlStr { + SqlStr(Repr::Boxed(self.0)) + } +} + +// Note: this is not implemented for `Rc` because it would make `QueryString: !Send`. +impl SqlSafeStr for AssertSqlSafe> { + #[inline] + fn into_sql_str(self) -> SqlStr { + SqlStr(Repr::Arced(self.0)) + } +} + +impl SqlSafeStr for AssertSqlSafe> { + #[inline] + fn into_sql_str(self) -> SqlStr { + SqlStr(Repr::ArcString(self.0)) + } +} + +impl SqlSafeStr for AssertSqlSafe> { + fn into_sql_str(self) -> SqlStr { + match self.0 { + Cow::Borrowed(str) => str.into_sql_str(), + Cow::Owned(str) => AssertSqlSafe(str).into_sql_str(), + } + } +} + +/// A SQL string that is ready to execute on a database connection. +/// +/// This is essentially `Cow<'static, str>` but which can be constructed from additional types +/// without copying. +/// +/// See [`SqlSafeStr`] for details. +#[derive(Debug)] +pub struct SqlStr(Repr); + +#[derive(Debug)] +enum Repr { + /// We need a variant to memoize when we already have a static string, so we don't copy it. + Static(&'static str), + /// Thanks to the new niche in `String`, this doesn't increase the size beyond 3 words. + /// We essentially get all these variants for free. + Owned(String), + Boxed(Box), + Arced(Arc), + /// Allows for dynamic shared ownership with `query_builder`. + ArcString(Arc), +} + +impl Clone for SqlStr { + fn clone(&self) -> Self { + Self(match &self.0 { + Repr::Static(s) => Repr::Static(s), + Repr::Arced(s) => Repr::Arced(s.clone()), + // If `.clone()` gets called once, assume it might get called again. + _ => Repr::Arced(self.as_str().into()), + }) + } +} + +impl SqlSafeStr for SqlStr { + #[inline] + fn into_sql_str(self) -> SqlStr { + self + } +} + +impl SqlStr { + /// Borrow the inner query string. + #[inline] + pub fn as_str(&self) -> &str { + match &self.0 { + Repr::Static(s) => s, + Repr::Owned(s) => s, + Repr::Boxed(s) => s, + Repr::Arced(s) => s, + Repr::ArcString(s) => s, + } + } + + pub const fn from_static(sql: &'static str) -> Self { + SqlStr(Repr::Static(sql)) + } +} + +impl AsRef for SqlStr { + #[inline] + fn as_ref(&self) -> &str { + self.as_str() + } +} + +impl Borrow for SqlStr { + #[inline] + fn borrow(&self) -> &str { + self.as_str() + } +} + +impl PartialEq for SqlStr +where + T: AsRef, +{ + fn eq(&self, other: &T) -> bool { + self.as_str() == other.as_ref() + } +} + +impl Eq for SqlStr {} + +impl Hash for SqlStr { + fn hash(&self, state: &mut H) { + self.as_str().hash(state) + } +} diff --git a/sqlx-core/src/statement.rs b/sqlx-core/src/statement.rs index 17dfd6428d..76d0325639 100644 --- a/sqlx-core/src/statement.rs +++ b/sqlx-core/src/statement.rs @@ -6,6 +6,7 @@ use crate::from_row::FromRow; use crate::query::Query; use crate::query_as::QueryAs; use crate::query_scalar::QueryScalar; +use crate::sql_str::SqlStr; use either::Either; /// An explicitly prepared statement. @@ -16,15 +17,14 @@ use either::Either; /// /// Statements can be re-used with any connection and on first-use it will be re-prepared and /// cached within the connection. -pub trait Statement<'q>: Send + Sync { +pub trait Statement: Send + Sync + Clone { type Database: Database; - /// Creates an owned statement from this statement reference. This copies - /// the original SQL text. - fn to_owned(&self) -> ::Statement<'static>; + /// Get the original SQL text used to create this statement. + fn into_sql(self) -> SqlStr; /// Get the original SQL text used to create this statement. - fn sql(&self) -> &str; + fn sql(&self) -> &SqlStr; /// Get the expected parameters for this statement. /// diff --git a/sqlx-core/src/testing/fixtures.rs b/sqlx-core/src/testing/fixtures.rs index 67670d8014..32fdfe2219 100644 --- a/sqlx-core/src/testing/fixtures.rs +++ b/sqlx-core/src/testing/fixtures.rs @@ -150,7 +150,7 @@ where } } - query.into_sql() + query.into_string() } } diff --git a/sqlx-core/src/transaction.rs b/sqlx-core/src/transaction.rs index d8f350c73e..917690339e 100644 --- a/sqlx-core/src/transaction.rs +++ b/sqlx-core/src/transaction.rs @@ -1,4 +1,3 @@ -use std::borrow::Cow; use std::fmt::{self, Debug, Formatter}; use std::future::{self, Future}; use std::ops::{Deref, DerefMut}; @@ -8,6 +7,7 @@ use futures_core::future::BoxFuture; use crate::database::Database; use crate::error::Error; use crate::pool::MaybePoolConnection; +use crate::sql_str::{AssertSqlSafe, SqlSafeStr, SqlStr}; /// Generic management of database transactions. /// @@ -22,10 +22,10 @@ pub trait TransactionManager { /// /// If we are already inside a transaction and `statement.is_some()`, then /// `Error::InvalidSavePoint` is returned without running any statements. - fn begin<'conn>( - conn: &'conn mut ::Connection, - statement: Option>, - ) -> impl Future> + Send + 'conn; + fn begin( + conn: &mut ::Connection, + statement: Option, + ) -> impl Future> + Send + '_; /// Commit the active transaction or release the most recent savepoint. fn commit( @@ -98,7 +98,7 @@ where #[doc(hidden)] pub fn begin( conn: impl Into>, - statement: Option>, + statement: Option, ) -> BoxFuture<'c, Result> { let mut conn = conn.into(); @@ -274,29 +274,30 @@ where } } -pub fn begin_ansi_transaction_sql(depth: usize) -> Cow<'static, str> { +pub fn begin_ansi_transaction_sql(depth: usize) -> SqlStr { if depth == 0 { - Cow::Borrowed("BEGIN") + "BEGIN".into_sql_str() } else { - Cow::Owned(format!("SAVEPOINT _sqlx_savepoint_{depth}")) + AssertSqlSafe(format!("SAVEPOINT _sqlx_savepoint_{depth}")).into_sql_str() } } -pub fn commit_ansi_transaction_sql(depth: usize) -> Cow<'static, str> { +pub fn commit_ansi_transaction_sql(depth: usize) -> SqlStr { if depth == 1 { - Cow::Borrowed("COMMIT") + "COMMIT".into_sql_str() } else { - Cow::Owned(format!("RELEASE SAVEPOINT _sqlx_savepoint_{}", depth - 1)) + AssertSqlSafe(format!("RELEASE SAVEPOINT _sqlx_savepoint_{}", depth - 1)).into_sql_str() } } -pub fn rollback_ansi_transaction_sql(depth: usize) -> Cow<'static, str> { +pub fn rollback_ansi_transaction_sql(depth: usize) -> SqlStr { if depth == 1 { - Cow::Borrowed("ROLLBACK") + "ROLLBACK".into_sql_str() } else { - Cow::Owned(format!( + AssertSqlSafe(format!( "ROLLBACK TO SAVEPOINT _sqlx_savepoint_{}", depth - 1 )) + .into_sql_str() } } diff --git a/sqlx-macros-core/src/database/mod.rs b/sqlx-macros-core/src/database/mod.rs index 50cd516c22..311dedf4d3 100644 --- a/sqlx-macros-core/src/database/mod.rs +++ b/sqlx-macros-core/src/database/mod.rs @@ -6,6 +6,8 @@ use sqlx_core::connection::Connection; use sqlx_core::database::Database; use sqlx_core::describe::Describe; use sqlx_core::executor::Executor; +use sqlx_core::sql_str::AssertSqlSafe; +use sqlx_core::sql_str::SqlSafeStr; use sqlx_core::type_checking::TypeChecking; #[cfg(any(feature = "postgres", feature = "mysql", feature = "_sqlite"))] @@ -77,7 +79,8 @@ impl CachingDescribeBlocking { } }; - conn.describe(query).await + conn.describe(AssertSqlSafe(query.to_string()).into_sql_str()) + .await }) } } diff --git a/sqlx-macros-core/src/migrate.rs b/sqlx-macros-core/src/migrate.rs index 4f051d1330..b855703c22 100644 --- a/sqlx-macros-core/src/migrate.rs +++ b/sqlx-macros-core/src/migrate.rs @@ -71,7 +71,7 @@ impl ToTokens for QuoteMigration { version: #version, description: ::std::borrow::Cow::Borrowed(#description), migration_type: #migration_type, - sql: ::std::borrow::Cow::Borrowed(#sql), + sql: ::sqlx::SqlStr::from_static(#sql), no_tx: #no_tx, checksum: ::std::borrow::Cow::Borrowed(&[ #(#checksum),* @@ -137,9 +137,9 @@ pub fn expand_with_path(config: &Config, path: &Path) -> crate::Result>, - ) -> BoxFuture<'_, sqlx_core::Result<()>> { + fn begin(&mut self, statement: Option) -> BoxFuture<'_, sqlx_core::Result<()>> { MySqlTransactionManager::begin(self, statement).boxed() } @@ -82,7 +79,7 @@ impl AnyConnectionBackend for MySqlConnection { fn fetch_many<'q>( &'q mut self, - query: &'q str, + query: SqlStr, persistent: bool, arguments: Option>, ) -> BoxStream<'q, sqlx_core::Result>> { @@ -108,7 +105,7 @@ impl AnyConnectionBackend for MySqlConnection { fn fetch_optional<'q>( &'q mut self, - query: &'q str, + query: SqlStr, persistent: bool, arguments: Option>, ) -> BoxFuture<'q, sqlx_core::Result>> { @@ -135,20 +132,17 @@ impl AnyConnectionBackend for MySqlConnection { fn prepare_with<'c, 'q: 'c>( &'c mut self, - sql: &'q str, + sql: SqlStr, _parameters: &[AnyTypeInfo], - ) -> BoxFuture<'c, sqlx_core::Result>> { + ) -> BoxFuture<'c, sqlx_core::Result> { Box::pin(async move { let statement = Executor::prepare_with(self, sql, &[]).await?; - AnyStatement::try_from_statement( - sql, - &statement, - statement.metadata.column_names.clone(), - ) + let column_names = statement.metadata.column_names.clone(); + AnyStatement::try_from_statement(statement, column_names) }) } - fn describe<'q>(&'q mut self, sql: &'q str) -> BoxFuture<'q, sqlx_core::Result>> { + fn describe(&mut self, sql: SqlStr) -> BoxFuture<'_, sqlx_core::Result>> { Box::pin(async move { let describe = Executor::describe(self, sql).await?; describe.try_into_any() diff --git a/sqlx-mysql/src/connection/executor.rs b/sqlx-mysql/src/connection/executor.rs index e29b231db5..4eae4bc569 100644 --- a/sqlx-mysql/src/connection/executor.rs +++ b/sqlx-mysql/src/connection/executor.rs @@ -23,7 +23,8 @@ use futures_core::stream::BoxStream; use futures_core::Stream; use futures_util::TryStreamExt; use sqlx_core::column::{ColumnOrigin, TableColumn}; -use std::{borrow::Cow, pin::pin, sync::Arc}; +use sqlx_core::sql_str::SqlStr; +use std::{pin::pin, sync::Arc}; impl MySqlConnection { async fn prepare_statement( @@ -102,13 +103,11 @@ impl MySqlConnection { #[allow(clippy::needless_lifetimes)] pub(crate) async fn run<'e, 'c: 'e, 'q: 'e>( &'c mut self, - sql: &'q str, + sql: SqlStr, arguments: Option, persistent: bool, ) -> Result, Error>> + 'e, Error> { - let mut logger = QueryLogger::new(sql, self.inner.log_settings.clone()); - self.inner.stream.wait_until_ready().await?; self.inner.stream.waiting.push_back(Waiting::Result); @@ -121,7 +120,7 @@ impl MySqlConnection { let (mut column_names, format, mut needs_metadata) = if let Some(arguments) = arguments { if persistent && self.inner.cache_statement.is_enabled() { let (id, metadata) = self - .get_or_prepare_statement(sql) + .get_or_prepare_statement(sql.as_str()) .await?; if arguments.types.len() != metadata.parameters { @@ -145,7 +144,7 @@ impl MySqlConnection { (metadata.column_names, MySqlValueFormat::Binary, false) } else { let (id, metadata) = self - .prepare_statement(sql) + .prepare_statement(sql.as_str()) .await?; if arguments.types.len() != metadata.parameters { @@ -172,10 +171,11 @@ impl MySqlConnection { } } else { // https://dev.mysql.com/doc/internals/en/com-query.html - self.inner.stream.send_packet(Query(sql)).await?; + self.inner.stream.send_packet(Query(sql.as_str())).await?; (Arc::default(), MySqlValueFormat::Text, true) }; + let mut logger = QueryLogger::new(sql, self.inner.log_settings.clone()); loop { // query response is a meta-packet which may be one of: @@ -287,11 +287,11 @@ impl<'c> Executor<'c> for &'c mut MySqlConnection { 'q: 'e, E: 'q, { - let sql = query.sql(); let arguments = query.take_arguments().map_err(Error::Encode); let persistent = query.persistent(); Box::pin(try_stream! { + let sql = query.sql(); let arguments = arguments?; let mut s = pin!(self.run(sql, arguments, persistent).await?); @@ -323,11 +323,11 @@ impl<'c> Executor<'c> for &'c mut MySqlConnection { }) } - fn prepare_with<'e, 'q: 'e>( + fn prepare_with<'e>( self, - sql: &'q str, + sql: SqlStr, _parameters: &'e [MySqlTypeInfo], - ) -> BoxFuture<'e, Result, Error>> + ) -> BoxFuture<'e, Result> where 'c: 'e, { @@ -335,9 +335,9 @@ impl<'c> Executor<'c> for &'c mut MySqlConnection { self.inner.stream.wait_until_ready().await?; let metadata = if self.inner.cache_statement.is_enabled() { - self.get_or_prepare_statement(sql).await?.1 + self.get_or_prepare_statement(sql.as_str()).await?.1 } else { - let (id, metadata) = self.prepare_statement(sql).await?; + let (id, metadata) = self.prepare_statement(sql.as_str()).await?; self.inner .stream @@ -348,7 +348,7 @@ impl<'c> Executor<'c> for &'c mut MySqlConnection { }; Ok(MySqlStatement { - sql: Cow::Borrowed(sql), + sql, // metadata has internal Arcs for expensive data structures metadata: metadata.clone(), }) @@ -356,14 +356,14 @@ impl<'c> Executor<'c> for &'c mut MySqlConnection { } #[doc(hidden)] - fn describe<'e, 'q: 'e>(self, sql: &'q str) -> BoxFuture<'e, Result, Error>> + fn describe<'e>(self, sql: SqlStr) -> BoxFuture<'e, Result, Error>> where 'c: 'e, { Box::pin(async move { self.inner.stream.wait_until_ready().await?; - let (id, metadata) = self.prepare_statement(sql).await?; + let (id, metadata) = self.prepare_statement(sql.as_str()).await?; self.inner .stream diff --git a/sqlx-mysql/src/connection/mod.rs b/sqlx-mysql/src/connection/mod.rs index 26613a31d1..c70d67f3d1 100644 --- a/sqlx-mysql/src/connection/mod.rs +++ b/sqlx-mysql/src/connection/mod.rs @@ -1,8 +1,8 @@ -use std::borrow::Cow; use std::fmt::{self, Debug, Formatter}; use std::future::Future; pub(crate) use sqlx_core::connection::*; +use sqlx_core::sql_str::SqlSafeStr; pub(crate) use stream::{MySqlStream, Waiting}; use crate::common::StatementCache; @@ -117,12 +117,12 @@ impl Connection for MySqlConnection { fn begin_with( &mut self, - statement: impl Into>, + statement: impl SqlSafeStr, ) -> impl Future, Error>> + Send + '_ where Self: Sized, { - Transaction::begin(self, Some(statement.into())) + Transaction::begin(self, Some(statement.into_sql_str())) } fn shrink_buffers(&mut self) { diff --git a/sqlx-mysql/src/database.rs b/sqlx-mysql/src/database.rs index d03a567284..0e3f51f532 100644 --- a/sqlx-mysql/src/database.rs +++ b/sqlx-mysql/src/database.rs @@ -28,7 +28,7 @@ impl Database for MySql { type Arguments<'q> = MySqlArguments; type ArgumentBuffer<'q> = Vec; - type Statement<'q> = MySqlStatement<'q>; + type Statement = MySqlStatement; const NAME: &'static str = "MySQL"; diff --git a/sqlx-mysql/src/migrate.rs b/sqlx-mysql/src/migrate.rs index c8803e3b8c..0176f93c26 100644 --- a/sqlx-mysql/src/migrate.rs +++ b/sqlx-mysql/src/migrate.rs @@ -2,6 +2,10 @@ use std::str::FromStr; use std::time::Duration; use std::time::Instant; +use futures_core::future::BoxFuture; +pub(crate) use sqlx_core::migrate::*; +use sqlx_core::sql_str::AssertSqlSafe; + use crate::connection::{ConnectOptions, Connection}; use crate::error::Error; use crate::executor::Executor; @@ -9,8 +13,6 @@ use crate::query::query; use crate::query_as::query_as; use crate::query_scalar::query_scalar; use crate::{MySql, MySqlConnectOptions, MySqlConnection}; -use futures_core::future::BoxFuture; -pub(crate) use sqlx_core::migrate::*; fn parse_for_maintenance(url: &str) -> Result<(MySqlConnectOptions, String), Error> { let mut options = MySqlConnectOptions::from_str(url)?; @@ -35,7 +37,7 @@ impl MigrateDatabase for MySql { let mut conn = options.connect().await?; let _ = conn - .execute(&*format!("CREATE DATABASE `{database}`")) + .execute(AssertSqlSafe(format!("CREATE DATABASE `{database}`"))) .await?; Ok(()) @@ -60,7 +62,9 @@ impl MigrateDatabase for MySql { let mut conn = options.connect().await?; let _ = conn - .execute(&*format!("DROP DATABASE IF EXISTS `{database}`")) + .execute(AssertSqlSafe(format!( + "DROP DATABASE IF EXISTS `{database}`" + ))) .await?; Ok(()) @@ -74,8 +78,10 @@ impl Migrate for MySqlConnection { ) -> BoxFuture<'e, Result<(), MigrateError>> { Box::pin(async move { // language=SQL - self.execute(&*format!(r#"CREATE SCHEMA IF NOT EXISTS {schema_name};"#)) - .await?; + self.execute(AssertSqlSafe(format!( + r#"CREATE SCHEMA IF NOT EXISTS {schema_name};"# + ))) + .await?; Ok(()) }) @@ -87,7 +93,7 @@ impl Migrate for MySqlConnection { ) -> BoxFuture<'e, Result<(), MigrateError>> { Box::pin(async move { // language=MySQL - self.execute(&*format!( + self.execute(AssertSqlSafe(format!( r#" CREATE TABLE IF NOT EXISTS {table_name} ( version BIGINT PRIMARY KEY, @@ -98,7 +104,7 @@ CREATE TABLE IF NOT EXISTS {table_name} ( execution_time BIGINT NOT NULL ); "# - )) + ))) .await?; Ok(()) @@ -111,9 +117,9 @@ CREATE TABLE IF NOT EXISTS {table_name} ( ) -> BoxFuture<'e, Result, MigrateError>> { Box::pin(async move { // language=SQL - let row: Option<(i64,)> = query_as(&format!( + let row: Option<(i64,)> = query_as(AssertSqlSafe(format!( "SELECT version FROM {table_name} WHERE success = false ORDER BY version LIMIT 1" - )) + ))) .fetch_optional(self) .await?; @@ -127,9 +133,9 @@ CREATE TABLE IF NOT EXISTS {table_name} ( ) -> BoxFuture<'e, Result, MigrateError>> { Box::pin(async move { // language=SQL - let rows: Vec<(i64, Vec)> = query_as(&format!( + let rows: Vec<(i64, Vec)> = query_as(AssertSqlSafe(format!( "SELECT version, checksum FROM {table_name} ORDER BY version" - )) + ))) .fetch_all(self) .await?; @@ -202,12 +208,12 @@ CREATE TABLE IF NOT EXISTS {table_name} ( // `success=FALSE` and later modify the flag. // // language=MySQL - let _ = query(&format!( + let _ = query(AssertSqlSafe(format!( r#" INSERT INTO {table_name} ( version, description, success, checksum, execution_time ) VALUES ( ?, ?, FALSE, ?, -1 ) "# - )) + ))) .bind(migration.version) .bind(&*migration.description) .bind(&*migration.checksum) @@ -215,18 +221,18 @@ CREATE TABLE IF NOT EXISTS {table_name} ( .await?; let _ = tx - .execute(&*migration.sql) + .execute(migration.sql.clone()) .await .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; // language=MySQL - let _ = query(&format!( + let _ = query(AssertSqlSafe(format!( r#" UPDATE {table_name} SET success = TRUE WHERE version = ? "# - )) + ))) .bind(migration.version) .execute(&mut *tx) .await?; @@ -240,13 +246,13 @@ CREATE TABLE IF NOT EXISTS {table_name} ( let elapsed = start.elapsed(); #[allow(clippy::cast_possible_truncation)] - let _ = query(&format!( + let _ = query(AssertSqlSafe(format!( r#" UPDATE {table_name} SET execution_time = ? WHERE version = ? "# - )) + ))) .bind(elapsed.as_nanos() as i64) .bind(migration.version) .execute(self) @@ -274,24 +280,26 @@ CREATE TABLE IF NOT EXISTS {table_name} ( // `success=FALSE` and later remove the migration altogether. // // language=MySQL - let _ = query(&format!( + let _ = query(AssertSqlSafe(format!( r#" UPDATE {table_name} SET success = FALSE WHERE version = ? "# - )) + ))) .bind(migration.version) .execute(&mut *tx) .await?; - tx.execute(&*migration.sql).await?; + tx.execute(migration.sql.clone()).await?; // language=SQL - let _ = query(&format!(r#"DELETE FROM {table_name} WHERE version = ?"#)) - .bind(migration.version) - .execute(&mut *tx) - .await?; + let _ = query(AssertSqlSafe(format!( + r#"DELETE FROM {table_name} WHERE version = ?"# + ))) + .bind(migration.version) + .execute(&mut *tx) + .await?; tx.commit().await?; diff --git a/sqlx-mysql/src/options/connect.rs b/sqlx-mysql/src/options/connect.rs index 58bf21ac2f..295a8b84a5 100644 --- a/sqlx-mysql/src/options/connect.rs +++ b/sqlx-mysql/src/options/connect.rs @@ -3,6 +3,7 @@ use crate::error::Error; use crate::executor::Executor; use crate::{MySqlConnectOptions, MySqlConnection}; use log::LevelFilter; +use sqlx_core::sql_str::AssertSqlSafe; use sqlx_core::Url; use std::time::Duration; @@ -75,7 +76,7 @@ impl ConnectOptions for MySqlConnectOptions { } if !options.is_empty() { - conn.execute(&*format!(r#"SET {};"#, options.join(","))) + conn.execute(AssertSqlSafe(format!(r#"SET {};"#, options.join(",")))) .await?; } diff --git a/sqlx-mysql/src/statement.rs b/sqlx-mysql/src/statement.rs index e9578403e1..711c8270d0 100644 --- a/sqlx-mysql/src/statement.rs +++ b/sqlx-mysql/src/statement.rs @@ -5,14 +5,14 @@ use crate::ext::ustr::UStr; use crate::HashMap; use crate::{MySql, MySqlArguments, MySqlTypeInfo}; use either::Either; -use std::borrow::Cow; +use sqlx_core::sql_str::SqlStr; use std::sync::Arc; pub(crate) use sqlx_core::statement::*; #[derive(Debug, Clone)] -pub struct MySqlStatement<'q> { - pub(crate) sql: Cow<'q, str>, +pub struct MySqlStatement { + pub(crate) sql: SqlStr, pub(crate) metadata: MySqlStatementMetadata, } @@ -23,17 +23,14 @@ pub(crate) struct MySqlStatementMetadata { pub(crate) parameters: usize, } -impl<'q> Statement<'q> for MySqlStatement<'q> { +impl Statement for MySqlStatement { type Database = MySql; - fn to_owned(&self) -> MySqlStatement<'static> { - MySqlStatement::<'static> { - sql: Cow::Owned(self.sql.clone().into_owned()), - metadata: self.metadata.clone(), - } + fn into_sql(self) -> SqlStr { + self.sql } - fn sql(&self) -> &str { + fn sql(&self) -> &SqlStr { &self.sql } @@ -48,8 +45,8 @@ impl<'q> Statement<'q> for MySqlStatement<'q> { impl_statement_query!(MySqlArguments); } -impl ColumnIndex> for &'_ str { - fn index(&self, statement: &MySqlStatement<'_>) -> Result { +impl ColumnIndex for &'_ str { + fn index(&self, statement: &MySqlStatement) -> Result { statement .metadata .column_names diff --git a/sqlx-mysql/src/testing/mod.rs b/sqlx-mysql/src/testing/mod.rs index 831403e7a5..f509f9da45 100644 --- a/sqlx-mysql/src/testing/mod.rs +++ b/sqlx-mysql/src/testing/mod.rs @@ -12,7 +12,7 @@ use crate::{MySql, MySqlConnectOptions, MySqlConnection, MySqlDatabaseError}; use sqlx_core::connection::Connection; use sqlx_core::query_builder::QueryBuilder; use sqlx_core::query_scalar::query_scalar; -use std::fmt::Write; +use sqlx_core::sql_str::AssertSqlSafe; pub(crate) use sqlx_core::testing::*; @@ -51,13 +51,12 @@ impl TestSupport for MySql { let mut deleted_db_names = Vec::with_capacity(delete_db_names.len()); - let mut command = String::new(); + let mut builder = QueryBuilder::new("drop database if exists "); for db_name in &delete_db_names { - command.clear(); + builder.push(db_name); - writeln!(command, "drop database if exists {db_name};").ok(); - match conn.execute(&*command).await { + match builder.build().execute(&mut conn).await { Ok(_deleted) => { deleted_db_names.push(db_name); } @@ -68,6 +67,8 @@ impl TestSupport for MySql { // Bubble up other errors Err(e) => return Err(e), } + + builder.reset(); } if deleted_db_names.is_empty() { @@ -157,7 +158,7 @@ async fn test_context(args: &TestArgs) -> Result, Error> { .execute(&mut *conn) .await?; - conn.execute(&format!("create database {db_name}")[..]) + conn.execute(AssertSqlSafe(format!("create database {db_name}"))) .await?; eprintln!("created database {db_name}"); @@ -182,7 +183,7 @@ async fn test_context(args: &TestArgs) -> Result, Error> { async fn do_cleanup(conn: &mut MySqlConnection, db_name: &str) -> Result<(), Error> { let delete_db_command = format!("drop database if exists {db_name};"); - conn.execute(&*delete_db_command).await?; + conn.execute(AssertSqlSafe(delete_db_command)).await?; query("delete from _sqlx_test_databases where db_name = ?") .bind(db_name) .execute(&mut *conn) @@ -221,9 +222,9 @@ async fn cleanup_old_dbs(conn: &mut MySqlConnection) -> Result<(), Error> { // Drop old-style test databases. for id in db_ids { match conn - .execute(&*format!( + .execute(AssertSqlSafe(format!( "drop database if exists _sqlx_test_database_{id}" - )) + ))) .await { Ok(_deleted) => (), diff --git a/sqlx-mysql/src/transaction.rs b/sqlx-mysql/src/transaction.rs index f17ef85bfd..18db30b183 100644 --- a/sqlx-mysql/src/transaction.rs +++ b/sqlx-mysql/src/transaction.rs @@ -1,4 +1,4 @@ -use std::borrow::Cow; +use sqlx_core::sql_str::SqlStr; use crate::connection::Waiting; use crate::error::Error; @@ -14,11 +14,9 @@ pub struct MySqlTransactionManager; impl TransactionManager for MySqlTransactionManager { type Database = MySql; - async fn begin( - conn: &mut MySqlConnection, - statement: Option>, - ) -> Result<(), Error> { + async fn begin(conn: &mut MySqlConnection, statement: Option) -> Result<(), Error> { let depth = conn.inner.transaction_depth; + let statement = match statement { // custom `BEGIN` statements are not allowed if we're already in a transaction // (we need to issue a `SAVEPOINT` instead) @@ -26,7 +24,7 @@ impl TransactionManager for MySqlTransactionManager { Some(statement) => statement, None => begin_ansi_transaction_sql(depth), }; - conn.execute(&*statement).await?; + conn.execute(statement).await?; if !conn.in_transaction() { return Err(Error::BeginFailed); } @@ -39,7 +37,7 @@ impl TransactionManager for MySqlTransactionManager { let depth = conn.inner.transaction_depth; if depth > 0 { - conn.execute(&*commit_ansi_transaction_sql(depth)).await?; + conn.execute(commit_ansi_transaction_sql(depth)).await?; conn.inner.transaction_depth = depth - 1; } @@ -50,7 +48,7 @@ impl TransactionManager for MySqlTransactionManager { let depth = conn.inner.transaction_depth; if depth > 0 { - conn.execute(&*rollback_ansi_transaction_sql(depth)).await?; + conn.execute(rollback_ansi_transaction_sql(depth)).await?; conn.inner.transaction_depth = depth - 1; } @@ -65,7 +63,7 @@ impl TransactionManager for MySqlTransactionManager { conn.inner.stream.sequence_id = 0; conn.inner .stream - .write_packet(Query(&rollback_ansi_transaction_sql(depth))) + .write_packet(Query(rollback_ansi_transaction_sql(depth).as_str())) .expect("BUG: unexpected error queueing ROLLBACK"); conn.inner.transaction_depth = depth - 1; diff --git a/sqlx-postgres/src/any.rs b/sqlx-postgres/src/any.rs index 96eadbe66c..75ee0d73df 100644 --- a/sqlx-postgres/src/any.rs +++ b/sqlx-postgres/src/any.rs @@ -5,7 +5,7 @@ use crate::{ use futures_core::future::BoxFuture; use futures_core::stream::BoxStream; use futures_util::{stream, FutureExt, StreamExt, TryFutureExt, TryStreamExt}; -use std::borrow::Cow; +use sqlx_core::sql_str::SqlStr; use std::{future, pin::pin}; use sqlx_core::any::{ @@ -40,10 +40,7 @@ impl AnyConnectionBackend for PgConnection { Connection::ping(self).boxed() } - fn begin( - &mut self, - statement: Option>, - ) -> BoxFuture<'_, sqlx_core::Result<()>> { + fn begin(&mut self, statement: Option) -> BoxFuture<'_, sqlx_core::Result<()>> { PgTransactionManager::begin(self, statement).boxed() } @@ -84,7 +81,7 @@ impl AnyConnectionBackend for PgConnection { fn fetch_many<'q>( &'q mut self, - query: &'q str, + query: SqlStr, persistent: bool, arguments: Option>, ) -> BoxStream<'q, sqlx_core::Result>> { @@ -110,7 +107,7 @@ impl AnyConnectionBackend for PgConnection { fn fetch_optional<'q>( &'q mut self, - query: &'q str, + query: SqlStr, persistent: bool, arguments: Option>, ) -> BoxFuture<'q, sqlx_core::Result>> { @@ -135,20 +132,17 @@ impl AnyConnectionBackend for PgConnection { fn prepare_with<'c, 'q: 'c>( &'c mut self, - sql: &'q str, + sql: SqlStr, _parameters: &[AnyTypeInfo], - ) -> BoxFuture<'c, sqlx_core::Result>> { + ) -> BoxFuture<'c, sqlx_core::Result> { Box::pin(async move { let statement = Executor::prepare_with(self, sql, &[]).await?; - AnyStatement::try_from_statement( - sql, - &statement, - statement.metadata.column_names.clone(), - ) + let colunn_names = statement.metadata.column_names.clone(); + AnyStatement::try_from_statement(statement, colunn_names) }) } - fn describe<'q>(&'q mut self, sql: &'q str) -> BoxFuture<'q, sqlx_core::Result>> { + fn describe<'c>(&mut self, sql: SqlStr) -> BoxFuture<'_, sqlx_core::Result>> { Box::pin(async move { let describe = Executor::describe(self, sql).await?; diff --git a/sqlx-postgres/src/connection/describe.rs b/sqlx-postgres/src/connection/describe.rs index 0334357a6c..dfe5286458 100644 --- a/sqlx-postgres/src/connection/describe.rs +++ b/sqlx-postgres/src/connection/describe.rs @@ -14,6 +14,7 @@ use crate::{PgColumn, PgConnection, PgTypeInfo}; use smallvec::SmallVec; use sqlx_core::column::{ColumnOrigin, TableColumn}; use sqlx_core::query_builder::QueryBuilder; +use sqlx_core::sql_str::AssertSqlSafe; use std::sync::Arc; /// Describes the type of the `pg_type.typtype` column @@ -619,7 +620,7 @@ WHERE rngtypid = $1 } let (Json(explains),): (Json>,) = - query_as(&explain).fetch_one(self).await?; + query_as(AssertSqlSafe(explain)).fetch_one(self).await?; let mut nullables = Vec::new(); diff --git a/sqlx-postgres/src/connection/executor.rs b/sqlx-postgres/src/connection/executor.rs index 93cf4ec6bc..c3862a3f74 100644 --- a/sqlx-postgres/src/connection/executor.rs +++ b/sqlx-postgres/src/connection/executor.rs @@ -17,8 +17,9 @@ use futures_core::stream::BoxStream; use futures_core::Stream; use futures_util::TryStreamExt; use sqlx_core::arguments::Arguments; +use sqlx_core::sql_str::SqlStr; use sqlx_core::Either; -use std::{borrow::Cow, pin::pin, sync::Arc}; +use std::{pin::pin, sync::Arc}; async fn prepare( conn: &mut PgConnection, @@ -209,13 +210,11 @@ impl PgConnection { pub(crate) async fn run<'e, 'c: 'e, 'q: 'e>( &'c mut self, - query: &'q str, + query: SqlStr, arguments: Option, persistent: bool, metadata_opt: Option>, ) -> Result, Error>> + 'e, Error> { - let mut logger = QueryLogger::new(query, self.inner.log_settings.clone()); - // before we continue, wait until we are "ready" to accept more queries self.wait_until_ready().await?; @@ -238,7 +237,13 @@ impl PgConnection { // prepare the statement if this our first time executing it // always return the statement ID here let (statement, metadata_) = self - .get_or_prepare(query, &arguments.types, persistent, metadata_opt, false) + .get_or_prepare( + query.as_str(), + &arguments.types, + persistent, + metadata_opt, + false, + ) .await?; metadata = metadata_; @@ -291,7 +296,7 @@ impl PgConnection { PgValueFormat::Binary } else { // Query will trigger a ReadyForQuery - self.inner.stream.write_msg(Query(query))?; + self.inner.stream.write_msg(Query(query.as_str()))?; self.inner.pending_ready_for_query_count += 1; // metadata starts out as "nothing" @@ -302,6 +307,7 @@ impl PgConnection { }; self.inner.stream.flush().await?; + let mut logger = QueryLogger::new(query, self.inner.log_settings.clone()); Ok(try_stream! { loop { @@ -402,12 +408,12 @@ impl<'c> Executor<'c> for &'c mut PgConnection { 'q: 'e, E: 'q, { - let sql = query.sql(); // False positive: https://github.com/rust-lang/rust-clippy/issues/12560 #[allow(clippy::map_clone)] let metadata = query.statement().map(|s| Arc::clone(&s.metadata)); let arguments = query.take_arguments().map_err(Error::Encode); let persistent = query.persistent(); + let sql = query.sql(); Box::pin(try_stream! { let arguments = arguments?; @@ -428,7 +434,6 @@ impl<'c> Executor<'c> for &'c mut PgConnection { 'q: 'e, E: 'q, { - let sql = query.sql(); // False positive: https://github.com/rust-lang/rust-clippy/issues/12560 #[allow(clippy::map_clone)] let metadata = query.statement().map(|s| Arc::clone(&s.metadata)); @@ -436,6 +441,7 @@ impl<'c> Executor<'c> for &'c mut PgConnection { let persistent = query.persistent(); Box::pin(async move { + let sql = query.sql(); let arguments = arguments?; let mut s = pin!(self.run(sql, arguments, persistent, metadata).await?); @@ -454,11 +460,11 @@ impl<'c> Executor<'c> for &'c mut PgConnection { }) } - fn prepare_with<'e, 'q: 'e>( + fn prepare_with<'e>( self, - sql: &'q str, + sql: SqlStr, parameters: &'e [PgTypeInfo], - ) -> BoxFuture<'e, Result, Error>> + ) -> BoxFuture<'e, Result> where 'c: 'e, { @@ -466,27 +472,23 @@ impl<'c> Executor<'c> for &'c mut PgConnection { self.wait_until_ready().await?; let (_, metadata) = self - .get_or_prepare(sql, parameters, true, None, true) + .get_or_prepare(sql.as_str(), parameters, true, None, true) .await?; - Ok(PgStatement { - sql: Cow::Borrowed(sql), - metadata, - }) + Ok(PgStatement { sql, metadata }) }) } - fn describe<'e, 'q: 'e>( - self, - sql: &'q str, - ) -> BoxFuture<'e, Result, Error>> + fn describe<'e>(self, sql: SqlStr) -> BoxFuture<'e, Result, Error>> where 'c: 'e, { Box::pin(async move { self.wait_until_ready().await?; - let (stmt_id, metadata) = self.get_or_prepare(sql, &[], true, None, true).await?; + let (stmt_id, metadata) = self + .get_or_prepare(sql.as_str(), &[], true, None, true) + .await?; let nullable = self.get_nullable_for_columns(stmt_id, &metadata).await?; diff --git a/sqlx-postgres/src/connection/mod.rs b/sqlx-postgres/src/connection/mod.rs index 5d2bcea31b..4e05cd867b 100644 --- a/sqlx-postgres/src/connection/mod.rs +++ b/sqlx-postgres/src/connection/mod.rs @@ -1,4 +1,3 @@ -use std::borrow::Cow; use std::collections::BTreeMap; use std::fmt::{self, Debug, Formatter}; use std::future::Future; @@ -20,6 +19,7 @@ use crate::types::Oid; use crate::{PgConnectOptions, PgTypeInfo, Postgres}; pub(crate) use sqlx_core::connection::*; +use sqlx_core::sql_str::SqlSafeStr; pub use self::stream::PgStream; @@ -193,12 +193,12 @@ impl Connection for PgConnection { fn begin_with( &mut self, - statement: impl Into>, + statement: impl SqlSafeStr, ) -> impl Future, Error>> + Send + '_ where Self: Sized, { - Transaction::begin(self, Some(statement.into())) + Transaction::begin(self, Some(statement.into_sql_str())) } fn cached_statements_size(&self) -> usize { diff --git a/sqlx-postgres/src/database.rs b/sqlx-postgres/src/database.rs index 876e295899..fbc762615b 100644 --- a/sqlx-postgres/src/database.rs +++ b/sqlx-postgres/src/database.rs @@ -30,7 +30,7 @@ impl Database for Postgres { type Arguments<'q> = PgArguments; type ArgumentBuffer<'q> = PgArgumentBuffer; - type Statement<'q> = PgStatement<'q>; + type Statement = PgStatement; const NAME: &'static str = "PostgreSQL"; diff --git a/sqlx-postgres/src/listener.rs b/sqlx-postgres/src/listener.rs index 32658534c4..639ec95441 100644 --- a/sqlx-postgres/src/listener.rs +++ b/sqlx-postgres/src/listener.rs @@ -7,6 +7,7 @@ use futures_core::future::BoxFuture; use futures_core::stream::{BoxStream, Stream}; use futures_util::{FutureExt, StreamExt, TryFutureExt, TryStreamExt}; use sqlx_core::acquire::Acquire; +use sqlx_core::sql_str::{AssertSqlSafe, SqlStr}; use sqlx_core::transaction::Transaction; use sqlx_core::Either; use tracing::Instrument; @@ -116,7 +117,7 @@ impl PgListener { pub async fn listen(&mut self, channel: &str) -> Result<(), Error> { self.connection() .await? - .execute(&*format!(r#"LISTEN "{}""#, ident(channel))) + .execute(AssertSqlSafe(format!(r#"LISTEN "{}""#, ident(channel)))) .await?; self.channels.push(channel.to_owned()); @@ -133,7 +134,10 @@ impl PgListener { self.channels.extend(channels.into_iter().map(|s| s.into())); let query = build_listen_all_query(&self.channels[beg..]); - self.connection().await?.execute(&*query).await?; + self.connection() + .await? + .execute(AssertSqlSafe(query)) + .await?; Ok(()) } @@ -145,7 +149,7 @@ impl PgListener { // UNLISTEN (we've disconnected anyways) if let Some(connection) = self.connection.as_mut() { connection - .execute(&*format!(r#"UNLISTEN "{}""#, ident(channel))) + .execute(AssertSqlSafe(format!(r#"UNLISTEN "{}""#, ident(channel)))) .await?; } @@ -176,7 +180,7 @@ impl PgListener { connection.inner.stream.notifications = self.buffer_tx.take(); connection - .execute(&*build_listen_all_query(&self.channels)) + .execute(AssertSqlSafe(build_listen_all_query(&self.channels))) .await?; self.connection = Some(connection); @@ -417,11 +421,11 @@ impl<'c> Executor<'c> for &'c mut PgListener { async move { self.connection().await?.fetch_optional(query).await }.boxed() } - fn prepare_with<'e, 'q: 'e>( + fn prepare_with<'e>( self, - query: &'q str, + query: SqlStr, parameters: &'e [PgTypeInfo], - ) -> BoxFuture<'e, Result, Error>> + ) -> BoxFuture<'e, Result> where 'c: 'e, { @@ -435,10 +439,7 @@ impl<'c> Executor<'c> for &'c mut PgListener { } #[doc(hidden)] - fn describe<'e, 'q: 'e>( - self, - query: &'q str, - ) -> BoxFuture<'e, Result, Error>> + fn describe<'e>(self, query: SqlStr) -> BoxFuture<'e, Result, Error>> where 'c: 'e, { diff --git a/sqlx-postgres/src/migrate.rs b/sqlx-postgres/src/migrate.rs index c0573ffa2c..b96c021be2 100644 --- a/sqlx-postgres/src/migrate.rs +++ b/sqlx-postgres/src/migrate.rs @@ -7,6 +7,7 @@ use futures_core::future::BoxFuture; pub(crate) use sqlx_core::migrate::MigrateError; pub(crate) use sqlx_core::migrate::{AppliedMigration, Migration}; pub(crate) use sqlx_core::migrate::{Migrate, MigrateDatabase}; +use sqlx_core::sql_str::AssertSqlSafe; use crate::connection::{ConnectOptions, Connection}; use crate::error::Error; @@ -44,10 +45,10 @@ impl MigrateDatabase for Postgres { let mut conn = options.connect().await?; let _ = conn - .execute(&*format!( + .execute(AssertSqlSafe(format!( "CREATE DATABASE \"{}\"", database.replace('"', "\"\"") - )) + ))) .await?; Ok(()) @@ -71,10 +72,10 @@ impl MigrateDatabase for Postgres { let mut conn = options.connect().await?; let _ = conn - .execute(&*format!( + .execute(AssertSqlSafe(format!( "DROP DATABASE IF EXISTS \"{}\"", database.replace('"', "\"\"") - )) + ))) .await?; Ok(()) @@ -92,10 +93,10 @@ impl MigrateDatabase for Postgres { let pid_type = if version >= 90200 { "pid" } else { "procpid" }; - conn.execute(&*format!( + conn.execute(AssertSqlSafe(format!( "SELECT pg_terminate_backend(pg_stat_activity.{pid_type}) FROM pg_stat_activity \ WHERE pg_stat_activity.datname = '{database}' AND {pid_type} <> pg_backend_pid()" - )) + ))) .await?; Self::drop_database(url).await @@ -109,8 +110,10 @@ impl Migrate for PgConnection { ) -> BoxFuture<'e, Result<(), MigrateError>> { Box::pin(async move { // language=SQL - self.execute(&*format!(r#"CREATE SCHEMA IF NOT EXISTS {schema_name};"#)) - .await?; + self.execute(AssertSqlSafe(format!( + r#"CREATE SCHEMA IF NOT EXISTS {schema_name};"# + ))) + .await?; Ok(()) }) @@ -122,7 +125,7 @@ impl Migrate for PgConnection { ) -> BoxFuture<'e, Result<(), MigrateError>> { Box::pin(async move { // language=SQL - self.execute(&*format!( + self.execute(AssertSqlSafe(format!( r#" CREATE TABLE IF NOT EXISTS {table_name} ( version BIGINT PRIMARY KEY, @@ -133,7 +136,7 @@ CREATE TABLE IF NOT EXISTS {table_name} ( execution_time BIGINT NOT NULL ); "# - )) + ))) .await?; Ok(()) @@ -146,9 +149,9 @@ CREATE TABLE IF NOT EXISTS {table_name} ( ) -> BoxFuture<'e, Result, MigrateError>> { Box::pin(async move { // language=SQL - let row: Option<(i64,)> = query_as(&format!( + let row: Option<(i64,)> = query_as(AssertSqlSafe(format!( "SELECT version FROM {table_name} WHERE success = false ORDER BY version LIMIT 1" - )) + ))) .fetch_optional(self) .await?; @@ -162,9 +165,9 @@ CREATE TABLE IF NOT EXISTS {table_name} ( ) -> BoxFuture<'e, Result, MigrateError>> { Box::pin(async move { // language=SQL - let rows: Vec<(i64, Vec)> = query_as(&format!( + let rows: Vec<(i64, Vec)> = query_as(AssertSqlSafe(format!( "SELECT version, checksum FROM {table_name} ORDER BY version" - )) + ))) .fetch_all(self) .await?; @@ -245,13 +248,13 @@ CREATE TABLE IF NOT EXISTS {table_name} ( // language=SQL #[allow(clippy::cast_possible_truncation)] - let _ = query(&format!( + let _ = query(AssertSqlSafe(format!( r#" UPDATE {table_name} SET execution_time = $1 WHERE version = $2 "# - )) + ))) .bind(elapsed.as_nanos() as i64) .bind(migration.version) .execute(self) @@ -293,17 +296,17 @@ async fn execute_migration( migration: &Migration, ) -> Result<(), MigrateError> { let _ = conn - .execute(&*migration.sql) + .execute(migration.sql.clone()) .await .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; // language=SQL - let _ = query(&format!( + let _ = query(AssertSqlSafe(format!( r#" INSERT INTO {table_name} ( version, description, success, checksum, execution_time ) VALUES ( $1, $2, TRUE, $3, -1 ) "# - )) + ))) .bind(migration.version) .bind(&*migration.description) .bind(&*migration.checksum) @@ -319,15 +322,17 @@ async fn revert_migration( migration: &Migration, ) -> Result<(), MigrateError> { let _ = conn - .execute(&*migration.sql) + .execute(migration.sql.clone()) .await .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; // language=SQL - let _ = query(&format!(r#"DELETE FROM {table_name} WHERE version = $1"#)) - .bind(migration.version) - .execute(conn) - .await?; + let _ = query(AssertSqlSafe(format!( + r#"DELETE FROM {table_name} WHERE version = $1"# + ))) + .bind(migration.version) + .execute(conn) + .await?; Ok(()) } diff --git a/sqlx-postgres/src/statement.rs b/sqlx-postgres/src/statement.rs index abd553af30..79be63440a 100644 --- a/sqlx-postgres/src/statement.rs +++ b/sqlx-postgres/src/statement.rs @@ -3,15 +3,15 @@ use crate::column::ColumnIndex; use crate::error::Error; use crate::ext::ustr::UStr; use crate::{PgArguments, Postgres}; -use std::borrow::Cow; use std::sync::Arc; +use sqlx_core::sql_str::SqlStr; pub(crate) use sqlx_core::statement::Statement; use sqlx_core::{Either, HashMap}; #[derive(Debug, Clone)] -pub struct PgStatement<'q> { - pub(crate) sql: Cow<'q, str>, +pub struct PgStatement { + pub(crate) sql: SqlStr, pub(crate) metadata: Arc, } @@ -24,17 +24,14 @@ pub(crate) struct PgStatementMetadata { pub(crate) parameters: Vec, } -impl<'q> Statement<'q> for PgStatement<'q> { +impl Statement for PgStatement { type Database = Postgres; - fn to_owned(&self) -> PgStatement<'static> { - PgStatement::<'static> { - sql: Cow::Owned(self.sql.clone().into_owned()), - metadata: self.metadata.clone(), - } + fn into_sql(self) -> SqlStr { + self.sql } - fn sql(&self) -> &str { + fn sql(&self) -> &SqlStr { &self.sql } @@ -49,8 +46,8 @@ impl<'q> Statement<'q> for PgStatement<'q> { impl_statement_query!(PgArguments); } -impl ColumnIndex> for &'_ str { - fn index(&self, statement: &PgStatement<'_>) -> Result { +impl ColumnIndex for &'_ str { + fn index(&self, statement: &PgStatement) -> Result { statement .metadata .column_names diff --git a/sqlx-postgres/src/testing/mod.rs b/sqlx-postgres/src/testing/mod.rs index cd9574360b..3e1cf0ddf7 100644 --- a/sqlx-postgres/src/testing/mod.rs +++ b/sqlx-postgres/src/testing/mod.rs @@ -1,4 +1,3 @@ -use std::fmt::Write; use std::future::Future; use std::ops::Deref; use std::str::FromStr; @@ -6,7 +5,9 @@ use std::sync::OnceLock; use std::time::Duration; use sqlx_core::connection::Connection; +use sqlx_core::query_builder::QueryBuilder; use sqlx_core::query_scalar::query_scalar; +use sqlx_core::sql_str::AssertSqlSafe; use crate::error::Error; use crate::executor::Executor; @@ -52,12 +53,12 @@ impl TestSupport for Postgres { let mut deleted_db_names = Vec::with_capacity(delete_db_names.len()); - let mut command = String::new(); + let mut builder = QueryBuilder::new("drop database if exists "); for db_name in &delete_db_names { - command.clear(); - writeln!(command, "drop database if exists {db_name:?};").ok(); - match conn.execute(&*command).await { + builder.push(db_name); + + match builder.build().execute(&mut conn).await { Ok(_deleted) => { deleted_db_names.push(db_name); } @@ -68,6 +69,8 @@ impl TestSupport for Postgres { // Bubble up other errors Err(e) => return Err(e), } + + builder.reset(); } query("delete from _sqlx_test.databases where db_name = any($1::text[])") @@ -163,7 +166,7 @@ async fn test_context(args: &TestArgs) -> Result, Error> { let create_command = format!("create database {db_name:?}"); debug_assert!(create_command.starts_with("create database \"")); - conn.execute(&(create_command)[..]).await?; + conn.execute(AssertSqlSafe(create_command)).await?; Ok(TestContext { pool_opts: PoolOptions::new() @@ -185,7 +188,7 @@ async fn test_context(args: &TestArgs) -> Result, Error> { async fn do_cleanup(conn: &mut PgConnection, db_name: &str) -> Result<(), Error> { let delete_db_command = format!("drop database if exists {db_name:?};"); - conn.execute(&*delete_db_command).await?; + conn.execute(AssertSqlSafe(delete_db_command)).await?; query("delete from _sqlx_test.databases where db_name = $1::text") .bind(db_name) .execute(&mut *conn) diff --git a/sqlx-postgres/src/transaction.rs b/sqlx-postgres/src/transaction.rs index 7c8bd6bb1b..3f4122ea82 100644 --- a/sqlx-postgres/src/transaction.rs +++ b/sqlx-postgres/src/transaction.rs @@ -1,5 +1,5 @@ use sqlx_core::database::Database; -use std::borrow::Cow; +use sqlx_core::sql_str::SqlStr; use crate::error::Error; use crate::executor::Executor; @@ -14,11 +14,9 @@ pub struct PgTransactionManager; impl TransactionManager for PgTransactionManager { type Database = Postgres; - async fn begin( - conn: &mut PgConnection, - statement: Option>, - ) -> Result<(), Error> { + async fn begin(conn: &mut PgConnection, statement: Option) -> Result<(), Error> { let depth = conn.inner.transaction_depth; + let statement = match statement { // custom `BEGIN` statements are not allowed if we're already in // a transaction (we need to issue a `SAVEPOINT` instead) @@ -28,7 +26,7 @@ impl TransactionManager for PgTransactionManager { }; let rollback = Rollback::new(conn); - rollback.conn.queue_simple_query(&statement)?; + rollback.conn.queue_simple_query(statement.as_str())?; rollback.conn.wait_until_ready().await?; if !rollback.conn.in_transaction() { return Err(Error::BeginFailed); @@ -41,7 +39,7 @@ impl TransactionManager for PgTransactionManager { async fn commit(conn: &mut PgConnection) -> Result<(), Error> { if conn.inner.transaction_depth > 0 { - conn.execute(&*commit_ansi_transaction_sql(conn.inner.transaction_depth)) + conn.execute(commit_ansi_transaction_sql(conn.inner.transaction_depth)) .await?; conn.inner.transaction_depth -= 1; @@ -52,10 +50,8 @@ impl TransactionManager for PgTransactionManager { async fn rollback(conn: &mut PgConnection) -> Result<(), Error> { if conn.inner.transaction_depth > 0 { - conn.execute(&*rollback_ansi_transaction_sql( - conn.inner.transaction_depth, - )) - .await?; + conn.execute(rollback_ansi_transaction_sql(conn.inner.transaction_depth)) + .await?; conn.inner.transaction_depth -= 1; } @@ -65,8 +61,10 @@ impl TransactionManager for PgTransactionManager { fn start_rollback(conn: &mut PgConnection) { if conn.inner.transaction_depth > 0 { - conn.queue_simple_query(&rollback_ansi_transaction_sql(conn.inner.transaction_depth)) - .expect("BUG: Rollback query somehow too large for protocol"); + conn.queue_simple_query( + rollback_ansi_transaction_sql(conn.inner.transaction_depth).as_str(), + ) + .expect("BUG: Rollback query somehow too large for protocol"); conn.inner.transaction_depth -= 1; } diff --git a/sqlx-sqlite/src/any.rs b/sqlx-sqlite/src/any.rs index d4c3d05595..50f1bc7f72 100644 --- a/sqlx-sqlite/src/any.rs +++ b/sqlx-sqlite/src/any.rs @@ -1,5 +1,3 @@ -use std::borrow::Cow; - use crate::{ Either, Sqlite, SqliteArgumentValue, SqliteArguments, SqliteColumn, SqliteConnectOptions, SqliteConnection, SqliteQueryResult, SqliteRow, SqliteTransactionManager, SqliteTypeInfo, @@ -12,6 +10,7 @@ use sqlx_core::any::{ Any, AnyArguments, AnyColumn, AnyConnectOptions, AnyConnectionBackend, AnyQueryResult, AnyRow, AnyStatement, AnyTypeInfo, AnyTypeInfoKind, AnyValueKind, }; +use sqlx_core::sql_str::SqlStr; use crate::type_info::DataType; use sqlx_core::connection::{ConnectOptions, Connection}; @@ -40,10 +39,7 @@ impl AnyConnectionBackend for SqliteConnection { Connection::ping(self).boxed() } - fn begin( - &mut self, - statement: Option>, - ) -> BoxFuture<'_, sqlx_core::Result<()>> { + fn begin(&mut self, statement: Option) -> BoxFuture<'_, sqlx_core::Result<()>> { SqliteTransactionManager::begin(self, statement).boxed() } @@ -84,7 +80,7 @@ impl AnyConnectionBackend for SqliteConnection { fn fetch_many<'q>( &'q mut self, - query: &'q str, + query: SqlStr, persistent: bool, arguments: Option>, ) -> BoxStream<'q, sqlx_core::Result>> { @@ -107,7 +103,7 @@ impl AnyConnectionBackend for SqliteConnection { fn fetch_optional<'q>( &'q mut self, - query: &'q str, + query: SqlStr, persistent: bool, arguments: Option>, ) -> BoxFuture<'q, sqlx_core::Result>> { @@ -132,16 +128,17 @@ impl AnyConnectionBackend for SqliteConnection { fn prepare_with<'c, 'q: 'c>( &'c mut self, - sql: &'q str, + sql: SqlStr, _parameters: &[AnyTypeInfo], - ) -> BoxFuture<'c, sqlx_core::Result>> { + ) -> BoxFuture<'c, sqlx_core::Result> { Box::pin(async move { let statement = Executor::prepare_with(self, sql, &[]).await?; - AnyStatement::try_from_statement(sql, &statement, statement.column_names.clone()) + let column_names = statement.column_names.clone(); + AnyStatement::try_from_statement(statement, column_names) }) } - fn describe<'q>(&'q mut self, sql: &'q str) -> BoxFuture<'q, sqlx_core::Result>> { + fn describe(&mut self, sql: SqlStr) -> BoxFuture<'_, sqlx_core::Result>> { Box::pin(async move { Executor::describe(self, sql).await?.try_into_any() }) } } diff --git a/sqlx-sqlite/src/connection/describe.rs b/sqlx-sqlite/src/connection/describe.rs index 6db81374aa..400c671d96 100644 --- a/sqlx-sqlite/src/connection/describe.rs +++ b/sqlx-sqlite/src/connection/describe.rs @@ -5,14 +5,18 @@ use crate::error::Error; use crate::statement::VirtualStatement; use crate::type_info::DataType; use crate::{Sqlite, SqliteColumn}; +use sqlx_core::sql_str::SqlStr; use sqlx_core::Either; use std::convert::identity; -pub(crate) fn describe(conn: &mut ConnectionState, query: &str) -> Result, Error> { +pub(crate) fn describe( + conn: &mut ConnectionState, + query: SqlStr, +) -> Result, Error> { // describing a statement from SQLite can be involved // each SQLx statement is comprised of multiple SQL statements - let mut statement = VirtualStatement::new(query, false)?; + let mut statement = VirtualStatement::new(query.as_str(), false)?; let mut columns = Vec::new(); let mut nullable = Vec::new(); diff --git a/sqlx-sqlite/src/connection/execute.rs b/sqlx-sqlite/src/connection/execute.rs index 8a76236977..7acbc91ff8 100644 --- a/sqlx-sqlite/src/connection/execute.rs +++ b/sqlx-sqlite/src/connection/execute.rs @@ -3,12 +3,13 @@ use crate::error::Error; use crate::logger::QueryLogger; use crate::statement::{StatementHandle, VirtualStatement}; use crate::{SqliteArguments, SqliteQueryResult, SqliteRow}; +use sqlx_core::sql_str::SqlSafeStr; use sqlx_core::Either; pub struct ExecuteIter<'a> { handle: &'a mut ConnectionHandle, statement: &'a mut VirtualStatement, - logger: QueryLogger<'a>, + logger: QueryLogger, args: Option>, /// since a `VirtualStatement` can encompass multiple actual statements, @@ -20,12 +21,13 @@ pub struct ExecuteIter<'a> { pub(crate) fn iter<'a>( conn: &'a mut ConnectionState, - query: &'a str, + query: impl SqlSafeStr, args: Option>, persistent: bool, ) -> Result, Error> { + let query = query.into_sql_str(); // fetch the cached statement or allocate a new one - let statement = conn.statements.get(query, persistent)?; + let statement = conn.statements.get(query.as_str(), persistent)?; let logger = QueryLogger::new(query, conn.log_settings.clone()); diff --git a/sqlx-sqlite/src/connection/executor.rs b/sqlx-sqlite/src/connection/executor.rs index 1f6ce7726f..0bc88cf14e 100644 --- a/sqlx-sqlite/src/connection/executor.rs +++ b/sqlx-sqlite/src/connection/executor.rs @@ -7,6 +7,7 @@ use futures_util::{stream, FutureExt, StreamExt, TryFutureExt, TryStreamExt}; use sqlx_core::describe::Describe; use sqlx_core::error::Error; use sqlx_core::executor::{Execute, Executor}; +use sqlx_core::sql_str::SqlStr; use sqlx_core::Either; use std::{future, pin::pin}; @@ -23,12 +24,12 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { 'q: 'e, E: 'q, { - let sql = query.sql(); let arguments = match query.take_arguments().map_err(Error::Encode) { Ok(arguments) => arguments, Err(error) => return stream::once(future::ready(Err(error))).boxed(), }; let persistent = query.persistent() && arguments.is_some(); + let sql = query.sql(); Box::pin( self.worker @@ -48,7 +49,6 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { 'q: 'e, E: 'q, { - let sql = query.sql(); let arguments = match query.take_arguments().map_err(Error::Encode) { Ok(arguments) => arguments, Err(error) => return future::ready(Err(error)).boxed(), @@ -56,6 +56,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { let persistent = query.persistent() && arguments.is_some(); Box::pin(async move { + let sql = query.sql(); let mut stream = pin!(self .worker .execute(sql, arguments, self.row_channel_size, persistent, Some(1)) @@ -72,29 +73,26 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { }) } - fn prepare_with<'e, 'q: 'e>( + fn prepare_with<'e>( self, - sql: &'q str, + sql: SqlStr, _parameters: &[SqliteTypeInfo], - ) -> BoxFuture<'e, Result, Error>> + ) -> BoxFuture<'e, Result> where 'c: 'e, { Box::pin(async move { let statement = self.worker.prepare(sql).await?; - Ok(SqliteStatement { - sql: sql.into(), - ..statement - }) + Ok(statement) }) } #[doc(hidden)] - fn describe<'e, 'q: 'e>(self, sql: &'q str) -> BoxFuture<'e, Result, Error>> + fn describe<'e>(self, sql: SqlStr) -> BoxFuture<'e, Result, Error>> where 'c: 'e, { - Box::pin(self.worker.describe(sql)) + Box::pin(async move { self.worker.describe(sql).await }) } } diff --git a/sqlx-sqlite/src/connection/explain.rs b/sqlx-sqlite/src/connection/explain.rs index bfa66aa12f..edd65ece49 100644 --- a/sqlx-sqlite/src/connection/explain.rs +++ b/sqlx-sqlite/src/connection/explain.rs @@ -12,6 +12,7 @@ use crate::from_row::FromRow; use crate::logger::{BranchParent, BranchResult, DebugDiff}; use crate::type_info::DataType; use crate::SqliteTypeInfo; +use sqlx_core::sql_str::AssertSqlSafe; use sqlx_core::{hash_map, HashMap}; use std::fmt::Debug; use std::str::from_utf8; @@ -567,7 +568,7 @@ pub(super) fn explain( ) -> Result<(Vec, Vec>), Error> { let root_block_cols = root_block_columns(conn)?; let program: Vec<(i64, String, i64, i64, i64, Vec)> = - execute::iter(conn, &format!("EXPLAIN {query}"), None, false)? + execute::iter(conn, AssertSqlSafe(format!("EXPLAIN {query}")), None, false)? .filter_map(|res| res.map(|either| either.right()).transpose()) .map(|row| FromRow::from_row(&row?)) .collect::, Error>>()?; diff --git a/sqlx-sqlite/src/connection/mod.rs b/sqlx-sqlite/src/connection/mod.rs index 3b70ed27e1..1483f2c07c 100644 --- a/sqlx-sqlite/src/connection/mod.rs +++ b/sqlx-sqlite/src/connection/mod.rs @@ -1,4 +1,3 @@ -use std::borrow::Cow; use std::cmp::Ordering; use std::ffi::CStr; use std::fmt::Write; @@ -22,6 +21,7 @@ use sqlx_core::common::StatementCache; pub(crate) use sqlx_core::connection::*; use sqlx_core::error::Error; use sqlx_core::executor::Executor; +use sqlx_core::sql_str::{AssertSqlSafe, SqlSafeStr}; use sqlx_core::transaction::Transaction; use crate::connection::establish::EstablishParams; @@ -222,7 +222,7 @@ impl Connection for SqliteConnection { write!(pragma_string, "PRAGMA analysis_limit = {limit}; ").ok(); } pragma_string.push_str("PRAGMA optimize;"); - self.execute(&*pragma_string).await?; + self.execute(AssertSqlSafe(pragma_string)).await?; } let shutdown = self.worker.shutdown(); // Drop the statement worker, which should @@ -250,12 +250,12 @@ impl Connection for SqliteConnection { fn begin_with( &mut self, - statement: impl Into>, + statement: impl SqlSafeStr, ) -> impl Future, Error>> + Send + '_ where Self: Sized, { - Transaction::begin(self, Some(statement.into())) + Transaction::begin(self, Some(statement.into_sql_str())) } fn cached_statements_size(&self) -> usize { diff --git a/sqlx-sqlite/src/connection/worker.rs b/sqlx-sqlite/src/connection/worker.rs index ec8b38f0f6..1be624b7c4 100644 --- a/sqlx-sqlite/src/connection/worker.rs +++ b/sqlx-sqlite/src/connection/worker.rs @@ -1,4 +1,3 @@ -use std::borrow::Cow; use std::future::Future; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; @@ -6,6 +5,7 @@ use std::thread; use futures_channel::oneshot; use futures_intrusive::sync::{Mutex, MutexGuard}; +use sqlx_core::sql_str::SqlStr; use tracing::span::Span; use sqlx_core::describe::Describe; @@ -53,15 +53,15 @@ impl WorkerSharedState { enum Command { Prepare { - query: Box, - tx: oneshot::Sender, Error>>, + query: SqlStr, + tx: oneshot::Sender>, }, Describe { - query: Box, + query: SqlStr, tx: oneshot::Sender, Error>>, }, Execute { - query: Box, + query: SqlStr, arguments: Option>, persistent: bool, tx: flume::Sender, Error>>, @@ -79,7 +79,7 @@ enum Command { }, Begin { tx: rendezvous_oneshot::Sender>, - statement: Option>, + statement: Option, }, Commit { tx: rendezvous_oneshot::Sender>, @@ -145,7 +145,7 @@ impl ConnectionWorker { let _guard = span.enter(); match cmd { Command::Prepare { query, tx } => { - tx.send(prepare(&mut conn, &query)).ok(); + tx.send(prepare(&mut conn, query)).ok(); // This may issue an unnecessary write on failure, // but it doesn't matter in the grand scheme of things. @@ -155,7 +155,7 @@ impl ConnectionWorker { ); } Command::Describe { query, tx } => { - tx.send(describe(&mut conn, &query)).ok(); + tx.send(describe(&mut conn, query)).ok(); } Command::Execute { query, @@ -164,7 +164,7 @@ impl ConnectionWorker { tx, limit } => { - let iter = match execute::iter(&mut conn, &query, arguments, persistent) + let iter = match execute::iter(&mut conn, query, arguments, persistent) { Ok(iter) => iter, Err(e) => { @@ -225,7 +225,7 @@ impl ConnectionWorker { }; let res = conn.handle - .exec(statement) + .exec(statement.as_str()) .map(|_| { shared.transaction_depth.fetch_add(1, Ordering::Release); }); @@ -238,7 +238,7 @@ impl ConnectionWorker { // immediately otherwise it would remain started forever. if let Err(error) = conn .handle - .exec(rollback_ansi_transaction_sql(depth + 1)) + .exec(rollback_ansi_transaction_sql(depth + 1).as_str()) .map(|_| { shared.transaction_depth.fetch_sub(1, Ordering::Release); }) @@ -256,7 +256,7 @@ impl ConnectionWorker { let res = if depth > 0 { conn.handle - .exec(commit_ansi_transaction_sql(depth)) + .exec(commit_ansi_transaction_sql(depth).as_str()) .map(|_| { shared.transaction_depth.fetch_sub(1, Ordering::Release); }) @@ -282,7 +282,7 @@ impl ConnectionWorker { let res = if depth > 0 { conn.handle - .exec(rollback_ansi_transaction_sql(depth)) + .exec(rollback_ansi_transaction_sql(depth).as_str()) .map(|_| { shared.transaction_depth.fetch_sub(1, Ordering::Release); }) @@ -335,25 +335,19 @@ impl ConnectionWorker { establish_rx.await.map_err(|_| Error::WorkerCrashed)? } - pub(crate) async fn prepare(&mut self, query: &str) -> Result, Error> { - self.oneshot_cmd(|tx| Command::Prepare { - query: query.into(), - tx, - }) - .await? + pub(crate) async fn prepare(&mut self, query: SqlStr) -> Result { + self.oneshot_cmd(|tx| Command::Prepare { query, tx }) + .await? } - pub(crate) async fn describe(&mut self, query: &str) -> Result, Error> { - self.oneshot_cmd(|tx| Command::Describe { - query: query.into(), - tx, - }) - .await? + pub(crate) async fn describe(&mut self, query: SqlStr) -> Result, Error> { + self.oneshot_cmd(|tx| Command::Describe { query, tx }) + .await? } pub(crate) async fn execute( &mut self, - query: &str, + query: SqlStr, args: Option>, chan_size: usize, persistent: bool, @@ -364,7 +358,7 @@ impl ConnectionWorker { self.command_tx .send_async(( Command::Execute { - query: query.into(), + query, arguments: args.map(SqliteArguments::into_static), persistent, tx, @@ -378,10 +372,7 @@ impl ConnectionWorker { Ok(rx) } - pub(crate) async fn begin( - &mut self, - statement: Option>, - ) -> Result<(), Error> { + pub(crate) async fn begin(&mut self, statement: Option) -> Result<(), Error> { self.oneshot_cmd_with_ack(|tx| Command::Begin { tx, statement }) .await? } @@ -495,9 +486,9 @@ impl ConnectionWorker { } } -fn prepare(conn: &mut ConnectionState, query: &str) -> Result, Error> { +fn prepare(conn: &mut ConnectionState, query: SqlStr) -> Result { // prepare statement object (or checkout from cache) - let statement = conn.statements.get(query, true)?; + let statement = conn.statements.get(query.as_str(), true)?; let mut parameters = 0; let mut columns = None; @@ -514,7 +505,7 @@ fn prepare(conn: &mut ConnectionState, query: &str) -> Result = SqliteArguments<'q>; type ArgumentBuffer<'q> = Vec>; - type Statement<'q> = SqliteStatement<'q>; + type Statement = SqliteStatement; const NAME: &'static str = "SQLite"; diff --git a/sqlx-sqlite/src/lib.rs b/sqlx-sqlite/src/lib.rs index e4a122b6bd..8429468be2 100644 --- a/sqlx-sqlite/src/lib.rs +++ b/sqlx-sqlite/src/lib.rs @@ -57,6 +57,7 @@ pub use options::{ }; pub use query_result::SqliteQueryResult; pub use row::SqliteRow; +use sqlx_core::sql_str::{AssertSqlSafe, SqlSafeStr}; pub use statement::SqliteStatement; pub use transaction::SqliteTransactionManager; pub use type_info::SqliteTypeInfo; @@ -132,9 +133,10 @@ pub fn describe_blocking(query: &str, database_url: &str) -> Result BoxFuture<'e, Result<(), MigrateError>> { Box::pin(async move { // Check if the schema already exists; if so, don't error. - let schema_version: Option = - query_scalar(&format!("PRAGMA {schema_name}.schema_version")) - .fetch_optional(&mut *self) - .await?; + let schema_version: Option = query_scalar(AssertSqlSafe(format!( + "PRAGMA {schema_name}.schema_version" + ))) + .fetch_optional(&mut *self) + .await?; if schema_version.is_some() { return Ok(()); @@ -86,7 +88,7 @@ impl Migrate for SqliteConnection { ) -> BoxFuture<'e, Result<(), MigrateError>> { Box::pin(async move { // language=SQLite - self.execute(&*format!( + self.execute(AssertSqlSafe(format!( r#" CREATE TABLE IF NOT EXISTS {table_name} ( version BIGINT PRIMARY KEY, @@ -97,7 +99,7 @@ CREATE TABLE IF NOT EXISTS {table_name} ( execution_time BIGINT NOT NULL ); "# - )) + ))) .await?; Ok(()) @@ -110,9 +112,9 @@ CREATE TABLE IF NOT EXISTS {table_name} ( ) -> BoxFuture<'e, Result, MigrateError>> { Box::pin(async move { // language=SQLite - let row: Option<(i64,)> = query_as(&format!( + let row: Option<(i64,)> = query_as(AssertSqlSafe(format!( "SELECT version FROM {table_name} WHERE success = false ORDER BY version LIMIT 1" - )) + ))) .fetch_optional(self) .await?; @@ -126,9 +128,9 @@ CREATE TABLE IF NOT EXISTS {table_name} ( ) -> BoxFuture<'e, Result, MigrateError>> { Box::pin(async move { // language=SQLite - let rows: Vec<(i64, Vec)> = query_as(&format!( + let rows: Vec<(i64, Vec)> = query_as(AssertSqlSafe(format!( "SELECT version, checksum FROM {table_name} ORDER BY version" - )) + ))) .fetch_all(self) .await?; @@ -167,17 +169,17 @@ CREATE TABLE IF NOT EXISTS {table_name} ( // data lineage and debugging reasons, so it is not super important if it is lost. So we initialize it to -1 // and update it once the actual transaction completed. let _ = tx - .execute(&*migration.sql) + .execute(migration.sql.clone()) .await .map_err(|e| MigrateError::ExecuteMigration(e, migration.version))?; // language=SQL - let _ = query(&format!( + let _ = query(AssertSqlSafe(format!( r#" INSERT INTO {table_name} ( version, description, success, checksum, execution_time ) VALUES ( ?1, ?2, TRUE, ?3, -1 ) "# - )) + ))) .bind(migration.version) .bind(&*migration.description) .bind(&*migration.checksum) @@ -194,13 +196,13 @@ CREATE TABLE IF NOT EXISTS {table_name} ( // language=SQL #[allow(clippy::cast_possible_truncation)] - let _ = query(&format!( + let _ = query(AssertSqlSafe(format!( r#" UPDATE {table_name} SET execution_time = ?1 WHERE version = ?2 "# - )) + ))) .bind(elapsed.as_nanos() as i64) .bind(migration.version) .execute(self) @@ -221,13 +223,15 @@ CREATE TABLE IF NOT EXISTS {table_name} ( let mut tx = self.begin().await?; let start = Instant::now(); - let _ = tx.execute(&*migration.sql).await?; + let _ = tx.execute(migration.sql.clone()).await?; // language=SQLite - let _ = query(&format!(r#"DELETE FROM {table_name} WHERE version = ?1"#)) - .bind(migration.version) - .execute(&mut *tx) - .await?; + let _ = query(AssertSqlSafe(format!( + r#"DELETE FROM {table_name} WHERE version = ?1"# + ))) + .bind(migration.version) + .execute(&mut *tx) + .await?; tx.commit().await?; diff --git a/sqlx-sqlite/src/options/connect.rs b/sqlx-sqlite/src/options/connect.rs index 111598b5fb..3f147981d1 100644 --- a/sqlx-sqlite/src/options/connect.rs +++ b/sqlx-sqlite/src/options/connect.rs @@ -3,6 +3,7 @@ use log::LevelFilter; use sqlx_core::connection::ConnectOptions; use sqlx_core::error::Error; use sqlx_core::executor::Executor; +use sqlx_core::sql_str::AssertSqlSafe; use std::fmt::Write; use std::str::FromStr; use std::time::Duration; @@ -34,7 +35,7 @@ impl ConnectOptions for SqliteConnectOptions { let mut conn = SqliteConnection::establish(self).await?; // Execute PRAGMAs - conn.execute(&*self.pragma_string()).await?; + conn.execute(AssertSqlSafe(self.pragma_string())).await?; if !self.collations.is_empty() { let mut locked = conn.lock_handle().await?; diff --git a/sqlx-sqlite/src/statement/mod.rs b/sqlx-sqlite/src/statement/mod.rs index 179b8eeaf7..ff7d841ab1 100644 --- a/sqlx-sqlite/src/statement/mod.rs +++ b/sqlx-sqlite/src/statement/mod.rs @@ -2,8 +2,8 @@ use crate::column::ColumnIndex; use crate::error::Error; use crate::ext::ustr::UStr; use crate::{Sqlite, SqliteArguments, SqliteColumn, SqliteTypeInfo}; +use sqlx_core::sql_str::SqlStr; use sqlx_core::{Either, HashMap}; -use std::borrow::Cow; use std::sync::Arc; pub(crate) use sqlx_core::statement::*; @@ -17,26 +17,21 @@ pub(crate) use r#virtual::VirtualStatement; #[derive(Debug, Clone)] #[allow(clippy::rc_buffer)] -pub struct SqliteStatement<'q> { - pub(crate) sql: Cow<'q, str>, +pub struct SqliteStatement { + pub(crate) sql: SqlStr, pub(crate) parameters: usize, pub(crate) columns: Arc>, pub(crate) column_names: Arc>, } -impl<'q> Statement<'q> for SqliteStatement<'q> { +impl Statement for SqliteStatement { type Database = Sqlite; - fn to_owned(&self) -> SqliteStatement<'static> { - SqliteStatement::<'static> { - sql: Cow::Owned(self.sql.clone().into_owned()), - parameters: self.parameters, - columns: Arc::clone(&self.columns), - column_names: Arc::clone(&self.column_names), - } + fn into_sql(self) -> SqlStr { + self.sql } - fn sql(&self) -> &str { + fn sql(&self) -> &SqlStr { &self.sql } @@ -51,8 +46,8 @@ impl<'q> Statement<'q> for SqliteStatement<'q> { impl_statement_query!(SqliteArguments<'_>); } -impl ColumnIndex> for &'_ str { - fn index(&self, statement: &SqliteStatement<'_>) -> Result { +impl ColumnIndex for &'_ str { + fn index(&self, statement: &SqliteStatement) -> Result { statement .column_names .get(*self) diff --git a/sqlx-sqlite/src/transaction.rs b/sqlx-sqlite/src/transaction.rs index 9cfaa98d61..145999ff2d 100644 --- a/sqlx-sqlite/src/transaction.rs +++ b/sqlx-sqlite/src/transaction.rs @@ -1,7 +1,7 @@ -use std::{borrow::Cow, future::Future}; +use std::future::Future; -use sqlx_core::error::Error; use sqlx_core::transaction::TransactionManager; +use sqlx_core::{error::Error, sql_str::SqlStr}; use crate::{Sqlite, SqliteConnection}; @@ -11,10 +11,7 @@ pub struct SqliteTransactionManager; impl TransactionManager for SqliteTransactionManager { type Database = Sqlite; - async fn begin( - conn: &mut SqliteConnection, - statement: Option>, - ) -> Result<(), Error> { + async fn begin(conn: &mut SqliteConnection, statement: Option) -> Result<(), Error> { let is_custom_statement = statement.is_some(); conn.worker.begin(statement).await?; if is_custom_statement { diff --git a/sqlx-test/src/lib.rs b/sqlx-test/src/lib.rs index 015ff04d93..0aefbf626b 100644 --- a/sqlx-test/src/lib.rs +++ b/sqlx-test/src/lib.rs @@ -108,13 +108,14 @@ macro_rules! test_unprepared_type { #[sqlx_macros::test] async fn [< test_unprepared_type_ $name >] () -> anyhow::Result<()> { use sqlx::prelude::*; + use sqlx_core::sql_str::AssertSqlSafe; use futures_util::TryStreamExt; let mut conn = sqlx_test::new::<$db>().await?; $( let query = format!("SELECT {}", $text); - let mut s = conn.fetch(&*query); + let mut s = conn.fetch(AssertSqlSafe(query)); let row = s.try_next().await?.unwrap(); let rec = row.try_get::<$ty, _>(0)?; @@ -137,13 +138,14 @@ macro_rules! __test_prepared_decode_type { #[sqlx_macros::test] async fn [< test_prepared_decode_type_ $name >] () -> anyhow::Result<()> { use sqlx::Row; + use sqlx_core::sql_str::AssertSqlSafe; let mut conn = sqlx_test::new::<$db>().await?; $( let query = format!("SELECT {}", $text); - let row = sqlx::query(&query) + let row = sqlx::query(AssertSqlSafe(query)) .fetch_one(&mut conn) .await?; @@ -166,6 +168,7 @@ macro_rules! __test_prepared_type { #[sqlx_macros::test] async fn [< test_prepared_type_ $name >] () -> anyhow::Result<()> { use sqlx::Row; + use sqlx_core::sql_str::AssertSqlSafe; let mut conn = sqlx_test::new::<$db>().await?; @@ -173,7 +176,7 @@ macro_rules! __test_prepared_type { let query = format!($sql, $text); println!("{query}"); - let row = sqlx::query(&query) + let row = sqlx::query(AssertSqlSafe(query)) .bind($value) .bind($value) .fetch_one(&mut conn) diff --git a/src/lib.rs b/src/lib.rs index fa4b8d0061..5a21c3c98a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -30,6 +30,7 @@ pub use sqlx_core::query_scalar::query_scalar_with_result as __query_scalar_with pub use sqlx_core::query_scalar::{query_scalar, query_scalar_with}; pub use sqlx_core::raw_sql::{raw_sql, RawSql}; pub use sqlx_core::row::Row; +pub use sqlx_core::sql_str::{AssertSqlSafe, SqlSafeStr, SqlStr}; pub use sqlx_core::statement::Statement; pub use sqlx_core::transaction::Transaction; pub use sqlx_core::type_info::TypeInfo; diff --git a/tests/any/any.rs b/tests/any/any.rs index 099ff7ddff..62dc20403e 100644 --- a/tests/any/any.rs +++ b/tests/any/any.rs @@ -1,5 +1,6 @@ use sqlx::any::AnyRow; use sqlx::{Any, Connection, Executor, Row}; +use sqlx_core::sql_str::AssertSqlSafe; use sqlx_test::new; #[sqlx_macros::test] @@ -106,7 +107,7 @@ async fn it_can_fail_and_recover() -> anyhow::Result<()> { // now try and use the connection let val: i32 = conn - .fetch_one(&*format!("SELECT {i}")) + .fetch_one(AssertSqlSafe(format!("SELECT {i}"))) .await? .get_unchecked(0); @@ -132,7 +133,7 @@ async fn it_can_fail_and_recover_with_pool() -> anyhow::Result<()> { // now try and use the connection let val: i32 = pool - .fetch_one(&*format!("SELECT {i}")) + .fetch_one(AssertSqlSafe(format!("SELECT {i}"))) .await? .get_unchecked(0); diff --git a/tests/any/pool.rs b/tests/any/pool.rs index 3130b4f1c6..a4849940b8 100644 --- a/tests/any/pool.rs +++ b/tests/any/pool.rs @@ -1,5 +1,6 @@ use sqlx::any::{AnyConnectOptions, AnyPoolOptions}; use sqlx::Executor; +use sqlx_core::sql_str::AssertSqlSafe; use std::sync::{ atomic::{AtomicI32, AtomicUsize, Ordering}, Arc, Mutex, @@ -111,7 +112,7 @@ async fn test_pool_callbacks() -> anyhow::Result<()> { id ); - conn.execute(&statement[..]).await?; + conn.execute(AssertSqlSafe(statement)).await?; Ok(()) }) }) diff --git a/tests/mysql/describe.rs b/tests/mysql/describe.rs index feb18dc5c6..d50c86a93a 100644 --- a/tests/mysql/describe.rs +++ b/tests/mysql/describe.rs @@ -1,12 +1,12 @@ use sqlx::mysql::MySql; -use sqlx::{Column, Executor, Type, TypeInfo}; +use sqlx::{Column, Executor, SqlSafeStr, Type, TypeInfo}; use sqlx_test::new; #[sqlx_macros::test] async fn it_describes_simple() -> anyhow::Result<()> { let mut conn = new::().await?; - let d = conn.describe("SELECT * FROM tweet").await?; + let d = conn.describe("SELECT * FROM tweet".into_sql_str()).await?; assert_eq!(d.columns()[0].name(), "id"); assert_eq!(d.columns()[1].name(), "created_at"); @@ -43,7 +43,9 @@ CREATE TEMPORARY TABLE with_bit_and_tinyint ( ) .await?; - let d = conn.describe("SELECT * FROM with_bit_and_tinyint").await?; + let d = conn + .describe("SELECT * FROM with_bit_and_tinyint".into_sql_str()) + .await?; assert_eq!(d.column(2).name(), "value_bool"); assert_eq!(d.column(2).type_info().name(), "BOOLEAN"); @@ -62,7 +64,7 @@ async fn uses_alias_name() -> anyhow::Result<()> { let mut conn = new::().await?; let d = conn - .describe("SELECT text AS tweet_text FROM tweet") + .describe("SELECT text AS tweet_text FROM tweet".into_sql_str()) .await?; assert_eq!(d.columns()[0].name(), "tweet_text"); diff --git a/tests/mysql/mysql.rs b/tests/mysql/mysql.rs index c7f7a47960..8337aacb29 100644 --- a/tests/mysql/mysql.rs +++ b/tests/mysql/mysql.rs @@ -1,7 +1,7 @@ use anyhow::Context; use futures_util::TryStreamExt; use sqlx::mysql::{MySql, MySqlConnection, MySqlPool, MySqlPoolOptions, MySqlRow}; -use sqlx::{Column, Connection, Executor, Row, Statement, TypeInfo}; +use sqlx::{Column, Connection, Executor, Row, SqlSafeStr, Statement, TypeInfo}; use sqlx_core::connection::ConnectOptions; use sqlx_mysql::MySqlConnectOptions; use sqlx_test::{new, setup_if_needed}; @@ -391,7 +391,9 @@ async fn it_can_prepare_then_execute() -> anyhow::Result<()> { .await? .last_insert_id(); - let statement = tx.prepare("SELECT * FROM tweet WHERE id = ?").await?; + let statement = tx + .prepare("SELECT * FROM tweet WHERE id = ?".into_sql_str()) + .await?; assert_eq!(statement.column(0).name(), "id"); assert_eq!(statement.column(1).name(), "created_at"); diff --git a/tests/postgres/derives.rs b/tests/postgres/derives.rs index f87bc7d134..86eac02065 100644 --- a/tests/postgres/derives.rs +++ b/tests/postgres/derives.rs @@ -1,6 +1,7 @@ use futures_util::TryStreamExt; use sqlx::postgres::types::PgRange; use sqlx::{Connection, Executor, FromRow, Postgres}; +use sqlx_core::sql_str::AssertSqlSafe; use sqlx_postgres::PgHasArrayType; use sqlx_test::{new, test_type}; use std::fmt::Debug; @@ -259,7 +260,7 @@ SELECT id, mood FROM people WHERE id = $1 let stmt = format!("SELECT id, mood FROM people WHERE id = {people_id}"); dbg!(&stmt); - let mut cursor = conn.fetch(&*stmt); + let mut cursor = conn.fetch(AssertSqlSafe(stmt)); let row = cursor.try_next().await?.unwrap(); let rec = PeopleRow::from_row(&row)?; diff --git a/tests/postgres/describe.rs b/tests/postgres/describe.rs index d128eb21b6..806d8082c8 100644 --- a/tests/postgres/describe.rs +++ b/tests/postgres/describe.rs @@ -1,11 +1,11 @@ -use sqlx::{postgres::Postgres, Column, Executor, TypeInfo}; +use sqlx::{postgres::Postgres, Column, Executor, SqlSafeStr, TypeInfo}; use sqlx_test::new; #[sqlx_macros::test] async fn it_describes_simple() -> anyhow::Result<()> { let mut conn = new::().await?; - let d = conn.describe("SELECT * FROM tweet").await?; + let d = conn.describe("SELECT * FROM tweet".into_sql_str()).await?; assert_eq!(d.columns()[0].name(), "id"); assert_eq!(d.columns()[1].name(), "created_at"); @@ -29,7 +29,7 @@ async fn it_describes_simple() -> anyhow::Result<()> { async fn it_describes_expression() -> anyhow::Result<()> { let mut conn = new::().await?; - let d = conn.describe("SELECT 1::int8 + 10").await?; + let d = conn.describe("SELECT 1::int8 + 10".into_sql_str()).await?; // ?column? will cause the macro to emit an error ad ask the user to explicitly name the type assert_eq!(d.columns()[0].name(), "?column?"); @@ -46,7 +46,9 @@ async fn it_describes_expression() -> anyhow::Result<()> { async fn it_describes_enum() -> anyhow::Result<()> { let mut conn = new::().await?; - let d = conn.describe("SELECT 'open'::status as _1").await?; + let d = conn + .describe("SELECT 'open'::status as _1".into_sql_str()) + .await?; assert_eq!(d.columns()[0].name(), "_1"); @@ -66,7 +68,9 @@ async fn it_describes_enum() -> anyhow::Result<()> { async fn it_describes_record() -> anyhow::Result<()> { let mut conn = new::().await?; - let d = conn.describe("SELECT (true, 10::int2)").await?; + let d = conn + .describe("SELECT (true, 10::int2)".into_sql_str()) + .await?; let ty = d.columns()[0].type_info(); assert_eq!(ty.name(), "RECORD"); @@ -79,7 +83,7 @@ async fn it_describes_composite() -> anyhow::Result<()> { let mut conn = new::().await?; let d = conn - .describe("SELECT ROW('name',10,500)::inventory_item") + .describe("SELECT ROW('name',10,500)::inventory_item".into_sql_str()) .await?; let ty = d.columns()[0].type_info(); diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index d90ef11ed8..c580bb4eed 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -5,7 +5,8 @@ use sqlx::postgres::{ PgAdvisoryLock, PgConnectOptions, PgConnection, PgDatabaseError, PgErrorPosition, PgListener, PgPoolOptions, PgRow, PgSeverity, Postgres, PG_COPY_MAX_DATA_LEN, }; -use sqlx::{Column, Connection, Executor, Row, Statement, TypeInfo}; +use sqlx::{Column, Connection, Executor, Row, SqlSafeStr, Statement, TypeInfo}; +use sqlx_core::sql_str::AssertSqlSafe; use sqlx_core::{bytes::Bytes, error::BoxDynError}; use sqlx_test::{new, pool, setup_if_needed}; use std::env; @@ -219,7 +220,7 @@ CREATE TEMPORARY TABLE json_stuff (obj json, obj2 jsonb); .await?; let query = "INSERT INTO json_stuff (obj, obj2) VALUES ($1, $2)"; - let _ = conn.describe(query).await?; + let _ = conn.describe(query.into_sql_str()).await?; let done = sqlx::query(query) .bind(serde_json::json!({ "a": "a" })) @@ -309,7 +310,10 @@ async fn it_can_fail_and_recover() -> anyhow::Result<()> { assert!(res.is_err()); // now try and use the connection - let val: i32 = conn.fetch_one(&*format!("SELECT {i}::int4")).await?.get(0); + let val: i32 = conn + .fetch_one(AssertSqlSafe(format!("SELECT {i}::int4"))) + .await? + .get(0); assert_eq!(val, i); } @@ -330,7 +334,10 @@ async fn it_can_fail_and_recover_with_pool() -> anyhow::Result<()> { assert!(res.is_err()); // now try and use the connection - let val: i32 = pool.fetch_one(&*format!("SELECT {i}::int4")).await?.get(0); + let val: i32 = pool + .fetch_one(AssertSqlSafe(format!("SELECT {i}::int4"))) + .await? + .get(0); assert_eq!(val, i); } @@ -803,7 +810,7 @@ async fn it_closes_statement_from_cache_issue_470() -> anyhow::Result<()> { let mut conn = PgConnection::connect_with(&options).await?; for i in 0..5 { - let row = sqlx::query(&*format!("SELECT {i}::int4 AS val")) + let row = sqlx::query(AssertSqlSafe(format!("SELECT {i}::int4 AS val"))) .fetch_one(&mut conn) .await?; @@ -874,7 +881,9 @@ async fn it_can_prepare_then_execute() -> anyhow::Result<()> { .fetch_one(&mut *tx) .await?; - let statement = tx.prepare("SELECT * FROM tweet WHERE id = $1").await?; + let statement = tx + .prepare("SELECT * FROM tweet WHERE id = $1".into_sql_str()) + .await?; assert_eq!(statement.column(0).name(), "id"); assert_eq!(statement.column(1).name(), "created_at"); @@ -960,7 +969,8 @@ async fn test_describe_outer_join_nullable() -> anyhow::Result<()> { .describe( "select tweet.id from tweet - inner join products on products.name = tweet.text", + inner join products on products.name = tweet.text" + .into_sql_str(), ) .await?; @@ -971,7 +981,8 @@ async fn test_describe_outer_join_nullable() -> anyhow::Result<()> { .describe( "select tweet.id from (values (null)) vals(val) - left join tweet on false", + left join tweet on false" + .into_sql_str(), ) .await?; @@ -985,7 +996,8 @@ from (values (null)) vals(val) .describe( "select tweet1.id, tweet2.id from tweet tweet1 - left join tweet tweet2 on false", + left join tweet tweet2 on false" + .into_sql_str(), ) .await?; @@ -998,7 +1010,8 @@ from (values (null)) vals(val) .describe( "select tweet1.id, tweet2.id from tweet tweet1 - right join tweet tweet2 on false", + right join tweet tweet2 on false" + .into_sql_str(), ) .await?; @@ -1011,7 +1024,8 @@ from (values (null)) vals(val) .describe( "select tweet1.id, tweet2.id from tweet tweet1 - full join tweet tweet2 on false", + full join tweet tweet2 on false" + .into_sql_str(), ) .await?; @@ -1120,8 +1134,10 @@ async fn test_listener_try_recv_buffered() -> anyhow::Result<()> { { let mut txn = notify_conn.begin().await?; for i in 0..5 { - txn.execute(format!("NOTIFY test_channel2, 'payload {i}'").as_str()) - .await?; + txn.execute(AssertSqlSafe(format!( + "NOTIFY test_channel2, 'payload {i}'" + ))) + .await?; } txn.commit().await?; } @@ -1972,7 +1988,8 @@ async fn test_postgres_bytea_hex_deserialization_errors() -> anyhow::Result<()> conn.execute("SET bytea_output = 'escape';").await?; for value in ["", "DEADBEEF"] { let query = format!("SELECT '\\x{value}'::bytea"); - let res: sqlx::Result> = conn.fetch_one(query.as_str()).await?.try_get(0usize); + let res: sqlx::Result> = + conn.fetch_one(AssertSqlSafe(query)).await?.try_get(0usize); // Deserialization only supports hex format so this should error and definitely not panic. res.unwrap_err(); } diff --git a/tests/postgres/query_builder.rs b/tests/postgres/query_builder.rs index cdec136976..5b73bcff35 100644 --- a/tests/postgres/query_builder.rs +++ b/tests/postgres/query_builder.rs @@ -54,18 +54,20 @@ fn test_build() { qb.push(" WHERE id = ").push_bind(42i32); let query = qb.build(); - assert_eq!(query.sql(), "SELECT * FROM users WHERE id = $1"); assert!(Execute::persistent(&query)); + assert_eq!(query.sql(), "SELECT * FROM users WHERE id = $1"); } #[test] fn test_reset() { let mut qb: QueryBuilder<'_, Postgres> = QueryBuilder::new(""); - let _query = qb - .push("SELECT * FROM users WHERE id = ") - .push_bind(42i32) - .build(); + { + let _query = qb + .push("SELECT * FROM users WHERE id = ") + .push_bind(42i32) + .build(); + } qb.reset(); @@ -97,8 +99,10 @@ fn test_query_builder_with_args() { .push_bind(42i32) .build(); + let args = query.take_arguments().unwrap().unwrap(); + let mut qb: QueryBuilder<'_, Postgres> = - QueryBuilder::with_arguments(query.sql(), query.take_arguments().unwrap().unwrap()); + QueryBuilder::with_arguments(query.sql().as_str(), args); let query = qb.push(" OR membership_level = ").push_bind(3i32).build(); assert_eq!( diff --git a/tests/postgres/rustsec.rs b/tests/postgres/rustsec.rs index 45fd76b9db..a0692be4c6 100644 --- a/tests/postgres/rustsec.rs +++ b/tests/postgres/rustsec.rs @@ -1,4 +1,5 @@ use sqlx::{Error, PgPool}; +use sqlx_core::sql_str::AssertSqlSafe; use std::{cmp, str}; @@ -114,7 +115,7 @@ async fn rustsec_2024_0363(pool: PgPool) -> anyhow::Result<()> { assert_eq!(wrapped_len, fake_payload_len); - let res = sqlx::raw_sql(&query) + let res = sqlx::raw_sql(AssertSqlSafe(query)) // Note: the connection may hang afterward // because `pending_ready_for_query_count` will underflow. .execute(&pool) diff --git a/tests/sqlite/describe.rs b/tests/sqlite/describe.rs index 5458eaaa82..4c0768a5e2 100644 --- a/tests/sqlite/describe.rs +++ b/tests/sqlite/describe.rs @@ -1,8 +1,8 @@ use sqlx::error::DatabaseError; use sqlx::sqlite::{SqliteConnectOptions, SqliteError}; -use sqlx::ConnectOptions; use sqlx::TypeInfo; use sqlx::{sqlite::Sqlite, Column, Executor}; +use sqlx::{ConnectOptions, SqlSafeStr}; use sqlx_test::new; use std::env; @@ -10,7 +10,7 @@ use std::env; async fn it_describes_simple() -> anyhow::Result<()> { let mut conn = new::().await?; - let info = conn.describe("SELECT * FROM tweet").await?; + let info = conn.describe("SELECT * FROM tweet".into_sql_str()).await?; let columns = info.columns(); assert_eq!(columns[0].name(), "id"); @@ -41,13 +41,15 @@ async fn it_describes_variables() -> anyhow::Result<()> { let mut conn = new::().await?; // without any context, we resolve to NULL - let info = conn.describe("SELECT ?1").await?; + let info = conn.describe("SELECT ?1".into_sql_str()).await?; assert_eq!(info.columns()[0].type_info().name(), "NULL"); assert_eq!(info.nullable(0), Some(true)); // nothing prevents the value from being bound to null // context can be provided by using CAST(_ as _) - let info = conn.describe("SELECT CAST(?1 AS REAL)").await?; + let info = conn + .describe("SELECT CAST(?1 AS REAL)".into_sql_str()) + .await?; assert_eq!(info.columns()[0].type_info().name(), "REAL"); assert_eq!(info.nullable(0), Some(true)); // nothing prevents the value from being bound to null @@ -60,7 +62,7 @@ async fn it_describes_expression() -> anyhow::Result<()> { let mut conn = new::().await?; let d = conn - .describe("SELECT 1 + 10, 5.12 * 2, 'Hello', x'deadbeef', null") + .describe("SELECT 1 + 10, 5.12 * 2, 'Hello', x'deadbeef', null".into_sql_str()) .await?; let columns = d.columns(); @@ -107,7 +109,7 @@ async fn it_describes_temporary_table() -> anyhow::Result<()> { .await?; let d = conn - .describe("SELECT * FROM empty_all_types_and_nulls") + .describe("SELECT * FROM empty_all_types_and_nulls".into_sql_str()) .await?; assert_eq!(d.columns().len(), 8); @@ -146,7 +148,7 @@ async fn it_describes_expression_from_empty_table() -> anyhow::Result<()> { .await?; let d = conn - .describe("SELECT COUNT(*), a + 1, name, 5.12, 'Hello' FROM _temp_empty") + .describe("SELECT COUNT(*), a + 1, name, 5.12, 'Hello' FROM _temp_empty".into_sql_str()) .await?; assert_eq!(d.columns()[0].type_info().name(), "INTEGER"); @@ -175,7 +177,7 @@ async fn it_describes_expression_from_empty_table_with_star() -> anyhow::Result< .await?; let d = conn - .describe("SELECT *, 5, 'Hello' FROM _temp_empty") + .describe("SELECT *, 5, 'Hello' FROM _temp_empty".into_sql_str()) .await?; assert_eq!(d.columns()[0].type_info().name(), "TEXT"); @@ -191,13 +193,16 @@ async fn it_describes_insert() -> anyhow::Result<()> { let mut conn = new::().await?; let d = conn - .describe("INSERT INTO tweet (id, text) VALUES (2, 'Hello')") + .describe("INSERT INTO tweet (id, text) VALUES (2, 'Hello')".into_sql_str()) .await?; assert_eq!(d.columns().len(), 0); let d = conn - .describe("INSERT INTO tweet (id, text) VALUES (2, 'Hello'); SELECT last_insert_rowid();") + .describe( + "INSERT INTO tweet (id, text) VALUES (2, 'Hello'); SELECT last_insert_rowid();" + .into_sql_str(), + ) .await?; assert_eq!(d.columns().len(), 1); @@ -217,7 +222,7 @@ async fn it_describes_insert_with_read_only() -> anyhow::Result<()> { let mut conn = options.connect().await?; let d = conn - .describe("INSERT INTO tweet (id, text) VALUES (2, 'Hello')") + .describe("INSERT INTO tweet (id, text) VALUES (2, 'Hello')".into_sql_str()) .await?; assert_eq!(d.columns().len(), 0); @@ -230,7 +235,7 @@ async fn it_describes_insert_with_returning() -> anyhow::Result<()> { let mut conn = new::().await?; let d = conn - .describe("INSERT INTO tweet (id, text) VALUES (2, 'Hello') RETURNING *") + .describe("INSERT INTO tweet (id, text) VALUES (2, 'Hello') RETURNING *".into_sql_str()) .await?; assert_eq!(d.columns().len(), 4); @@ -240,7 +245,9 @@ async fn it_describes_insert_with_returning() -> anyhow::Result<()> { assert_eq!(d.nullable(1), Some(false)); let d = conn - .describe("INSERT INTO accounts (name, is_active) VALUES ('a', true) RETURNING id") + .describe( + "INSERT INTO accounts (name, is_active) VALUES ('a', true) RETURNING id".into_sql_str(), + ) .await?; assert_eq!(d.columns().len(), 1); @@ -254,7 +261,7 @@ async fn it_describes_insert_with_returning() -> anyhow::Result<()> { async fn it_describes_bound_columns_non_null() -> anyhow::Result<()> { let mut conn = new::().await?; let d = conn - .describe("INSERT INTO tweet (id, text) VALUES ($1, $2) returning *") + .describe("INSERT INTO tweet (id, text) VALUES ($1, $2) returning *".into_sql_str()) .await?; assert_eq!(d.columns().len(), 4); @@ -271,7 +278,7 @@ async fn it_describes_update_with_returning() -> anyhow::Result<()> { let mut conn = new::().await?; let d = conn - .describe("UPDATE accounts SET is_active=true WHERE name=?1 RETURNING id") + .describe("UPDATE accounts SET is_active=true WHERE name=?1 RETURNING id".into_sql_str()) .await?; assert_eq!(d.columns().len(), 1); @@ -279,7 +286,7 @@ async fn it_describes_update_with_returning() -> anyhow::Result<()> { assert_eq!(d.nullable(0), Some(false)); let d = conn - .describe("UPDATE accounts SET is_active=true WHERE id=?1 RETURNING *") + .describe("UPDATE accounts SET is_active=true WHERE id=?1 RETURNING *".into_sql_str()) .await?; assert_eq!(d.columns().len(), 3); @@ -291,7 +298,7 @@ async fn it_describes_update_with_returning() -> anyhow::Result<()> { //assert_eq!(d.nullable(2), Some(false)); //query analysis is allowed to notice that it is always set to true by the update let d = conn - .describe("UPDATE accounts SET is_active=true WHERE id=?1 RETURNING id") + .describe("UPDATE accounts SET is_active=true WHERE id=?1 RETURNING id".into_sql_str()) .await?; assert_eq!(d.columns().len(), 1); @@ -306,7 +313,7 @@ async fn it_describes_delete_with_returning() -> anyhow::Result<()> { let mut conn = new::().await?; let d = conn - .describe("DELETE FROM accounts WHERE name=?1 RETURNING id") + .describe("DELETE FROM accounts WHERE name=?1 RETURNING id".into_sql_str()) .await?; assert_eq!(d.columns().len(), 1); @@ -320,7 +327,10 @@ async fn it_describes_delete_with_returning() -> anyhow::Result<()> { async fn it_describes_bad_statement() -> anyhow::Result<()> { let mut conn = new::().await?; - let err = conn.describe("SELECT 1 FROM not_found").await.unwrap_err(); + let err = conn + .describe("SELECT 1 FROM not_found".into_sql_str()) + .await + .unwrap_err(); let err = err .as_database_error() .unwrap() @@ -336,13 +346,18 @@ async fn it_describes_bad_statement() -> anyhow::Result<()> { async fn it_describes_left_join() -> anyhow::Result<()> { let mut conn = new::().await?; - let d = conn.describe("select accounts.id from accounts").await?; + let d = conn + .describe("select accounts.id from accounts".into_sql_str()) + .await?; assert_eq!(d.column(0).type_info().name(), "INTEGER"); assert_eq!(d.nullable(0), Some(false)); let d = conn - .describe("select tweet.id from accounts left join tweet on owner_id = accounts.id") + .describe( + "select tweet.id from accounts left join tweet on owner_id = accounts.id" + .into_sql_str(), + ) .await?; assert_eq!(d.column(0).type_info().name(), "INTEGER"); @@ -350,7 +365,8 @@ async fn it_describes_left_join() -> anyhow::Result<()> { let d = conn .describe( - "select tweet.id, accounts.id from accounts left join tweet on owner_id = accounts.id", + "select tweet.id, accounts.id from accounts left join tweet on owner_id = accounts.id" + .into_sql_str(), ) .await?; @@ -362,7 +378,8 @@ async fn it_describes_left_join() -> anyhow::Result<()> { let d = conn .describe( - "select tweet.id, accounts.id from accounts inner join tweet on owner_id = accounts.id", + "select tweet.id, accounts.id from accounts inner join tweet on owner_id = accounts.id" + .into_sql_str(), ) .await?; @@ -374,7 +391,8 @@ async fn it_describes_left_join() -> anyhow::Result<()> { let d = conn .describe( - "select tweet.id, accounts.id from accounts left join tweet on tweet.id = accounts.id", + "select tweet.id, accounts.id from accounts left join tweet on tweet.id = accounts.id" + .into_sql_str(), ) .await?; @@ -391,18 +409,20 @@ async fn it_describes_left_join() -> anyhow::Result<()> { async fn it_describes_group_by() -> anyhow::Result<()> { let mut conn = new::().await?; - let d = conn.describe("select id from accounts group by id").await?; + let d = conn + .describe("select id from accounts group by id".into_sql_str()) + .await?; assert_eq!(d.column(0).type_info().name(), "INTEGER"); assert_eq!(d.nullable(0), Some(false)); let d = conn - .describe("SELECT name from accounts GROUP BY 1 LIMIT -1 OFFSET 1") + .describe("SELECT name from accounts GROUP BY 1 LIMIT -1 OFFSET 1".into_sql_str()) .await?; assert_eq!(d.column(0).type_info().name(), "TEXT"); assert_eq!(d.nullable(0), Some(false)); let d = conn - .describe("SELECT sum(id), sum(is_sent) from tweet GROUP BY owner_id") + .describe("SELECT sum(id), sum(is_sent) from tweet GROUP BY owner_id".into_sql_str()) .await?; assert_eq!(d.column(0).type_info().name(), "INTEGER"); assert_eq!(d.nullable(0), Some(false)); @@ -416,16 +436,20 @@ async fn it_describes_group_by() -> anyhow::Result<()> { async fn it_describes_ungrouped_aggregate() -> anyhow::Result<()> { let mut conn = new::().await?; - let d = conn.describe("select count(1) from accounts").await?; + let d = conn + .describe("select count(1) from accounts".into_sql_str()) + .await?; assert_eq!(d.column(0).type_info().name(), "INTEGER"); assert_eq!(d.nullable(0), Some(false)); - let d = conn.describe("SELECT sum(is_sent) from tweet").await?; + let d = conn + .describe("SELECT sum(is_sent) from tweet".into_sql_str()) + .await?; assert_eq!(d.column(0).type_info().name(), "INTEGER"); assert_eq!(d.nullable(0), Some(true)); let d = conn - .describe("SELECT coalesce(sum(is_sent),0) from tweet") + .describe("SELECT coalesce(sum(is_sent),0) from tweet".into_sql_str()) .await?; assert_eq!(d.column(0).type_info().name(), "INTEGER"); assert_eq!(d.nullable(0), Some(false)); @@ -437,9 +461,9 @@ async fn it_describes_ungrouped_aggregate() -> anyhow::Result<()> { async fn it_describes_literal_subquery() -> anyhow::Result<()> { async fn assert_literal_described( conn: &mut sqlx::SqliteConnection, - query: &str, + query: &'static str, ) -> anyhow::Result<()> { - let info = conn.describe(query).await?; + let info = conn.describe(query.into_sql_str()).await?; assert_eq!(info.column(0).type_info().name(), "TEXT", "{query}"); assert_eq!(info.nullable(0), Some(false), "{query}"); @@ -473,9 +497,9 @@ async fn it_describes_literal_subquery() -> anyhow::Result<()> { async fn assert_tweet_described( conn: &mut sqlx::SqliteConnection, - query: &str, + query: &'static str, ) -> anyhow::Result<()> { - let info = conn.describe(query).await?; + let info = conn.describe(query.into_sql_str()).await?; let columns = info.columns(); assert_eq!(columns[0].name(), "id", "{query}"); @@ -533,9 +557,9 @@ async fn it_describes_table_order_by() -> anyhow::Result<()> { async fn assert_literal_order_by_described( conn: &mut sqlx::SqliteConnection, - query: &str, + query: &'static str, ) -> anyhow::Result<()> { - let info = conn.describe(query).await?; + let info = conn.describe(query.into_sql_str()).await?; assert_eq!(info.column(0).type_info().name(), "TEXT", "{query}"); assert_eq!(info.nullable(0), Some(false), "{query}"); @@ -571,9 +595,9 @@ async fn it_describes_table_order_by() -> anyhow::Result<()> { async fn it_describes_union() -> anyhow::Result<()> { async fn assert_union_described( conn: &mut sqlx::SqliteConnection, - query: &str, + query: &'static str, ) -> anyhow::Result<()> { - let info = conn.describe(query).await?; + let info = conn.describe(query.into_sql_str()).await?; assert_eq!(info.column(0).type_info().name(), "TEXT", "{query}"); assert_eq!(info.nullable(0), Some(false), "{query}"); @@ -638,7 +662,8 @@ async fn it_describes_having_group_by() -> anyhow::Result<()> { ) single_reply_count FROM accounts WHERE id = ?1 - "#, + "# + .into_sql_str(), ) .await?; @@ -653,11 +678,11 @@ async fn it_describes_having_group_by() -> anyhow::Result<()> { async fn it_describes_strange_queries() -> anyhow::Result<()> { async fn assert_single_column_described( conn: &mut sqlx::SqliteConnection, - query: &str, + query: &'static str, typename: &str, nullable: bool, ) -> anyhow::Result<()> { - let info = conn.describe(query).await?; + let info = conn.describe(query.into_sql_str()).await?; assert_eq!(info.column(0).type_info().name(), typename, "{query}"); assert_eq!(info.nullable(0), Some(nullable), "{query}"); @@ -759,22 +784,22 @@ async fn it_describes_func_date() -> anyhow::Result<()> { let mut conn = new::().await?; let query = "SELECT date();"; - let info = conn.describe(query).await?; + let info = conn.describe(query.into_sql_str()).await?; assert_eq!(info.column(0).type_info().name(), "TEXT", "{query}"); assert_eq!(info.nullable(0), Some(false), "{query}"); let query = "SELECT date('now');"; - let info = conn.describe(query).await?; + let info = conn.describe(query.into_sql_str()).await?; assert_eq!(info.column(0).type_info().name(), "TEXT", "{query}"); assert_eq!(info.nullable(0), Some(true), "{query}"); //can't prove that it's not-null yet let query = "SELECT date('now', 'start of month');"; - let info = conn.describe(query).await?; + let info = conn.describe(query.into_sql_str()).await?; assert_eq!(info.column(0).type_info().name(), "TEXT", "{query}"); assert_eq!(info.nullable(0), Some(true), "{query}"); //can't prove that it's not-null yet let query = "SELECT date(:datebind);"; - let info = conn.describe(query).await?; + let info = conn.describe(query.into_sql_str()).await?; assert_eq!(info.column(0).type_info().name(), "TEXT", "{query}"); assert_eq!(info.nullable(0), Some(true), "{query}"); Ok(()) @@ -785,22 +810,22 @@ async fn it_describes_func_time() -> anyhow::Result<()> { let mut conn = new::().await?; let query = "SELECT time();"; - let info = conn.describe(query).await?; + let info = conn.describe(query.into_sql_str()).await?; assert_eq!(info.column(0).type_info().name(), "TEXT", "{query}"); assert_eq!(info.nullable(0), Some(false), "{query}"); let query = "SELECT time('now');"; - let info = conn.describe(query).await?; + let info = conn.describe(query.into_sql_str()).await?; assert_eq!(info.column(0).type_info().name(), "TEXT", "{query}"); assert_eq!(info.nullable(0), Some(true), "{query}"); //can't prove that it's not-null yet let query = "SELECT time('now', 'start of month');"; - let info = conn.describe(query).await?; + let info = conn.describe(query.into_sql_str()).await?; assert_eq!(info.column(0).type_info().name(), "TEXT", "{query}"); assert_eq!(info.nullable(0), Some(true), "{query}"); //can't prove that it's not-null yet let query = "SELECT time(:datebind);"; - let info = conn.describe(query).await?; + let info = conn.describe(query.into_sql_str()).await?; assert_eq!(info.column(0).type_info().name(), "TEXT", "{query}"); assert_eq!(info.nullable(0), Some(true), "{query}"); Ok(()) @@ -811,22 +836,22 @@ async fn it_describes_func_datetime() -> anyhow::Result<()> { let mut conn = new::().await?; let query = "SELECT datetime();"; - let info = conn.describe(query).await?; + let info = conn.describe(query.into_sql_str()).await?; assert_eq!(info.column(0).type_info().name(), "TEXT", "{query}"); assert_eq!(info.nullable(0), Some(false), "{query}"); let query = "SELECT datetime('now');"; - let info = conn.describe(query).await?; + let info = conn.describe(query.into_sql_str()).await?; assert_eq!(info.column(0).type_info().name(), "TEXT", "{query}"); assert_eq!(info.nullable(0), Some(true), "{query}"); //can't prove that it's not-null yet let query = "SELECT datetime('now', 'start of month');"; - let info = conn.describe(query).await?; + let info = conn.describe(query.into_sql_str()).await?; assert_eq!(info.column(0).type_info().name(), "TEXT", "{query}"); assert_eq!(info.nullable(0), Some(true), "{query}"); //can't prove that it's not-null yet let query = "SELECT datetime(:datebind);"; - let info = conn.describe(query).await?; + let info = conn.describe(query.into_sql_str()).await?; assert_eq!(info.column(0).type_info().name(), "TEXT", "{query}"); assert_eq!(info.nullable(0), Some(true), "{query}"); Ok(()) @@ -837,22 +862,22 @@ async fn it_describes_func_julianday() -> anyhow::Result<()> { let mut conn = new::().await?; let query = "SELECT julianday();"; - let info = conn.describe(query).await?; + let info = conn.describe(query.into_sql_str()).await?; assert_eq!(info.column(0).type_info().name(), "REAL", "{query}"); assert_eq!(info.nullable(0), Some(false), "{query}"); let query = "SELECT julianday('now');"; - let info = conn.describe(query).await?; + let info = conn.describe(query.into_sql_str()).await?; assert_eq!(info.column(0).type_info().name(), "REAL", "{query}"); assert_eq!(info.nullable(0), Some(true), "{query}"); //can't prove that it's not-null yet let query = "SELECT julianday('now', 'start of month');"; - let info = conn.describe(query).await?; + let info = conn.describe(query.into_sql_str()).await?; assert_eq!(info.column(0).type_info().name(), "REAL", "{query}"); assert_eq!(info.nullable(0), Some(true), "{query}"); //can't prove that it's not-null yet let query = "SELECT julianday(:datebind);"; - let info = conn.describe(query).await?; + let info = conn.describe(query.into_sql_str()).await?; assert_eq!(info.column(0).type_info().name(), "REAL", "{query}"); assert_eq!(info.nullable(0), Some(true), "{query}"); Ok(()) @@ -863,17 +888,17 @@ async fn it_describes_func_strftime() -> anyhow::Result<()> { let mut conn = new::().await?; let query = "SELECT strftime('%s','now');"; - let info = conn.describe(query).await?; + let info = conn.describe(query.into_sql_str()).await?; assert_eq!(info.column(0).type_info().name(), "TEXT", "{query}"); assert_eq!(info.nullable(0), Some(true), "{query}"); //can't prove that it's not-null yet let query = "SELECT strftime('%s', 'now', 'start of month');"; - let info = conn.describe(query).await?; + let info = conn.describe(query.into_sql_str()).await?; assert_eq!(info.column(0).type_info().name(), "TEXT", "{query}"); assert_eq!(info.nullable(0), Some(true), "{query}"); //can't prove that it's not-null yet let query = "SELECT strftime('%s',:datebind);"; - let info = conn.describe(query).await?; + let info = conn.describe(query.into_sql_str()).await?; assert_eq!(info.column(0).type_info().name(), "TEXT", "{query}"); assert_eq!(info.nullable(0), Some(true), "{query}"); Ok(()) @@ -897,7 +922,7 @@ async fn it_describes_with_recursive() -> anyhow::Result<()> { FROM schedule GROUP BY begin_date "; - let info = conn.describe(query).await?; + let info = conn.describe(query.into_sql_str()).await?; assert_eq!(info.column(0).type_info().name(), "TEXT", "{query}"); assert_eq!(info.nullable(0), Some(true), "{query}"); @@ -915,7 +940,7 @@ async fn it_describes_with_recursive() -> anyhow::Result<()> { FROM schedule GROUP BY begin_date "; - let info = conn.describe(query).await?; + let info = conn.describe(query.into_sql_str()).await?; assert_eq!(info.column(0).type_info().name(), "TEXT", "{query}"); assert_eq!(info.nullable(0), Some(true), "{query}"); @@ -927,97 +952,99 @@ async fn it_describes_analytical_function() -> anyhow::Result<()> { let mut conn = new::().await?; let d = conn - .describe("select row_number() over () from accounts") + .describe("select row_number() over () from accounts".into_sql_str()) .await?; dbg!(&d); assert_eq!(d.column(0).type_info().name(), "INTEGER"); assert_eq!(d.nullable(0), Some(false)); - let d = conn.describe("select rank() over () from accounts").await?; + let d = conn + .describe("select rank() over () from accounts".into_sql_str()) + .await?; dbg!(&d); assert_eq!(d.column(0).type_info().name(), "INTEGER"); assert_eq!(d.nullable(0), Some(false)); let d = conn - .describe("select dense_rank() over () from accounts") + .describe("select dense_rank() over () from accounts".into_sql_str()) .await?; assert_eq!(d.column(0).type_info().name(), "INTEGER"); assert_eq!(d.nullable(0), Some(false)); let d = conn - .describe("select percent_rank() over () from accounts") + .describe("select percent_rank() over () from accounts".into_sql_str()) .await?; assert_eq!(d.column(0).type_info().name(), "REAL"); assert_eq!(d.nullable(0), Some(false)); let d = conn - .describe("select cume_dist() over () from accounts") + .describe("select cume_dist() over () from accounts".into_sql_str()) .await?; assert_eq!(d.column(0).type_info().name(), "REAL"); assert_eq!(d.nullable(0), Some(false)); let d = conn - .describe("select ntile(1) over () from accounts") + .describe("select ntile(1) over () from accounts".into_sql_str()) .await?; assert_eq!(d.column(0).type_info().name(), "INTEGER"); assert_eq!(d.nullable(0), Some(false)); let d = conn - .describe("select lag(id) over () from accounts") + .describe("select lag(id) over () from accounts".into_sql_str()) .await?; assert_eq!(d.column(0).type_info().name(), "INTEGER"); assert_eq!(d.nullable(0), Some(true)); let d = conn - .describe("select lag(name) over () from accounts") + .describe("select lag(name) over () from accounts".into_sql_str()) .await?; assert_eq!(d.column(0).type_info().name(), "TEXT"); assert_eq!(d.nullable(0), Some(true)); let d = conn - .describe("select lead(id) over () from accounts") + .describe("select lead(id) over () from accounts".into_sql_str()) .await?; assert_eq!(d.column(0).type_info().name(), "INTEGER"); assert_eq!(d.nullable(0), Some(true)); let d = conn - .describe("select lead(name) over () from accounts") + .describe("select lead(name) over () from accounts".into_sql_str()) .await?; assert_eq!(d.column(0).type_info().name(), "TEXT"); assert_eq!(d.nullable(0), Some(true)); let d = conn - .describe("select first_value(id) over () from accounts") + .describe("select first_value(id) over () from accounts".into_sql_str()) .await?; assert_eq!(d.column(0).type_info().name(), "INTEGER"); assert_eq!(d.nullable(0), Some(true)); let d = conn - .describe("select first_value(name) over () from accounts") + .describe("select first_value(name) over () from accounts".into_sql_str()) .await?; assert_eq!(d.column(0).type_info().name(), "TEXT"); assert_eq!(d.nullable(0), Some(true)); let d = conn - .describe("select last_value(id) over () from accounts") + .describe("select last_value(id) over () from accounts".into_sql_str()) .await?; assert_eq!(d.column(0).type_info().name(), "INTEGER"); assert_eq!(d.nullable(0), Some(false)); let d = conn - .describe("select first_value(name) over () from accounts") + .describe("select first_value(name) over () from accounts".into_sql_str()) .await?; assert_eq!(d.column(0).type_info().name(), "TEXT"); //assert_eq!(d.nullable(0), Some(false)); //this should be null, but it's hard to prove that it will be let d = conn - .describe("select nth_value(id,10) over () from accounts") + .describe("select nth_value(id,10) over () from accounts".into_sql_str()) .await?; assert_eq!(d.column(0).type_info().name(), "INTEGER"); assert_eq!(d.nullable(0), Some(true)); let d = conn - .describe("select nth_value(name,10) over () from accounts") + .describe("select nth_value(name,10) over () from accounts".into_sql_str()) .await?; assert_eq!(d.column(0).type_info().name(), "TEXT"); assert_eq!(d.nullable(0), Some(true)); diff --git a/tests/sqlite/rustsec.rs b/tests/sqlite/rustsec.rs index 3ff9c524fa..08f88a3ad9 100644 --- a/tests/sqlite/rustsec.rs +++ b/tests/sqlite/rustsec.rs @@ -1,4 +1,4 @@ -use sqlx::{Connection, Error, SqliteConnection}; +use sqlx::{AssertSqlSafe, Connection, Error, SqliteConnection}; // https://rustsec.org/advisories/RUSTSEC-2024-0363.html // @@ -50,7 +50,7 @@ async fn rustsec_2024_0363() -> anyhow::Result<()> { .execute(&mut conn) .await?; - let res = sqlx::raw_sql(&query).execute(&mut conn).await; + let res = sqlx::raw_sql(AssertSqlSafe(query)).execute(&mut conn).await; if let Err(e) = res { // Connection rejected the query; we're happy. diff --git a/tests/sqlite/sqlite.rs b/tests/sqlite/sqlite.rs index 8f8f269580..b9bc6320f8 100644 --- a/tests/sqlite/sqlite.rs +++ b/tests/sqlite/sqlite.rs @@ -2,6 +2,7 @@ use futures_util::TryStreamExt; use rand::{Rng, SeedableRng}; use rand_xoshiro::Xoshiro256PlusPlus; use sqlx::sqlite::{SqliteConnectOptions, SqliteOperation, SqlitePoolOptions}; +use sqlx::SqlSafeStr; use sqlx::{ query, sqlite::Sqlite, sqlite::SqliteRow, Column, ConnectOptions, Connection, Executor, Row, SqliteConnection, SqlitePool, Statement, TypeInfo, @@ -504,7 +505,9 @@ async fn it_can_prepare_then_execute() -> anyhow::Result<()> { let tweet_id: i32 = 2; - let statement = tx.prepare("SELECT * FROM tweet WHERE id = ?1").await?; + let statement = tx + .prepare("SELECT * FROM tweet WHERE id = ?1".into_sql_str()) + .await?; assert_eq!(statement.column(0).name(), "id"); assert_eq!(statement.column(1).name(), "text");