Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions sqlx-postgres/src/connection/establish.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,21 @@ impl PgConnection {
("TimeZone", "UTC"),
];

if let Some(ref extra_float_digits) = options.extra_float_digits {
if let Some(extra_float_digits) = options.get_extra_float_digits() {
params.push(("extra_float_digits", extra_float_digits));
}

if let Some(ref application_name) = options.application_name {
if let Some(application_name) = options.get_application_name() {
params.push(("application_name", application_name));
}

if let Some(ref options) = options.options {
if let Some(options) = options.get_options() {
params.push(("options", options));
}

stream.write(Startup {
username: Some(&options.username),
database: options.database.as_deref(),
username: Some(options.get_username()),
database: Some(options.get_database()),
params: &params,
})?;

Expand Down Expand Up @@ -77,7 +77,7 @@ impl PgConnection {

stream
.send(Password::Cleartext(
options.password.as_deref().unwrap_or_default(),
options.get_password().unwrap_or_default(),
))
.await?;
}
Expand All @@ -90,8 +90,8 @@ impl PgConnection {

stream
.send(Password::Md5 {
username: &options.username,
password: options.password.as_deref().unwrap_or_default(),
username: options.get_username(),
password: options.get_password().unwrap_or_default(),
salt: body.salt,
})
.await?;
Expand Down
4 changes: 2 additions & 2 deletions sqlx-postgres/src/connection/sasl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ pub(crate) async fn authenticate(
BASE64_STANDARD.encode_string(GS2_HEADER, &mut channel_binding);

// "n=" saslname ;; Usernames are prepared using SASLprep.
let username = format!("{}={}", USERNAME_ATTR, options.username);
let username = format!("{}={}", USERNAME_ATTR, options.get_username());
let username = match saslprep(&username) {
Ok(v) => v,
// TODO(danielakhterov): Remove panic when we have proper support for configuration errors
Expand Down Expand Up @@ -87,7 +87,7 @@ pub(crate) async fn authenticate(

// SaltedPassword := Hi(Normalize(password), salt, i)
let salted_password = hi(
options.password.as_deref().unwrap_or_default(),
options.get_password().unwrap_or_default(),
&cont.salt,
cont.iterations,
)?;
Expand Down
9 changes: 8 additions & 1 deletion sqlx-postgres/src/connection/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,14 @@ impl PgStream {
pub(super) async fn connect(options: &PgConnectOptions) -> Result<Self, Error> {
let socket_result = match options.fetch_socket() {
Some(ref path) => net::connect_uds(path, MaybeUpgradeTls(options)).await?,
None => net::connect_tcp(&options.host, options.port, MaybeUpgradeTls(options)).await?,
None => {
net::connect_tcp(
options.get_host(),
options.get_port(),
MaybeUpgradeTls(options),
)
.await?
}
};

let socket = socket_result?;
Expand Down
8 changes: 4 additions & 4 deletions sqlx-postgres/src/connection/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ async fn maybe_upgrade<S: Socket>(
options: &PgConnectOptions,
) -> Result<Box<dyn Socket>, Error> {
// https://www.postgresql.org/docs/12/libpq-ssl.html#LIBPQ-SSL-SSLMODE-STATEMENTS
match options.ssl_mode {
match options.get_ssl_mode() {
// FIXME: Implement ALLOW
PgSslMode::Allow | PgSslMode::Disable => return Ok(Box::new(socket)),

Expand All @@ -46,15 +46,15 @@ async fn maybe_upgrade<S: Socket>(
}

let accept_invalid_certs = !matches!(
options.ssl_mode,
options.get_ssl_mode(),
PgSslMode::VerifyCa | PgSslMode::VerifyFull
);
let accept_invalid_hostnames = !matches!(options.ssl_mode, PgSslMode::VerifyFull);
let accept_invalid_hostnames = !matches!(options.get_ssl_mode(), PgSslMode::VerifyFull);

let config = TlsConfig {
accept_invalid_certs,
accept_invalid_hostnames,
hostname: &options.host,
hostname: options.get_host(),
root_cert_path: options.ssl_root_cert.as_ref(),
client_cert_path: options.ssl_client_cert.as_ref(),
client_key_path: options.ssl_client_key.as_ref(),
Expand Down
18 changes: 6 additions & 12 deletions sqlx-postgres/src/migrate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,19 @@ use crate::query_scalar::query_scalar;
use crate::{PgConnectOptions, PgConnection, Postgres};

fn parse_for_maintenance(url: &str) -> Result<(PgConnectOptions, String), Error> {
let mut options = PgConnectOptions::from_str(url)?;
let options = PgConnectOptions::from_str(url)?;

// pull out the name of the database to create
let database = options
.database
.as_deref()
.unwrap_or(&options.username)
.to_owned();
let database = options.get_database().to_owned();

// switch us to the maintenance database
// use `postgres` _unless_ the database is postgres, in which case, use `template1`
// this matches the behavior of the `createdb` util
options.database = if database == "postgres" {
Some("template1".into())
if database == "postgres" {
Ok((options.database("template1"), database))
} else {
Some("postgres".into())
};

Ok((options, database))
Ok((options.database("postgres"), database))
}
}

impl MigrateDatabase for Postgres {
Expand Down
47 changes: 38 additions & 9 deletions sqlx-postgres/src/options/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,10 @@ impl PgConnectOptions {
pub(crate) fn apply_pgpass(mut self) -> Self {
if self.password.is_none() {
self.password = pgpass::load_password(
&self.host,
self.port,
&self.username,
self.database.as_deref(),
self.get_host(),
self.get_port(),
self.get_username(),
self.get_database(),
);
}

Expand Down Expand Up @@ -519,18 +519,34 @@ impl PgConnectOptions {
&self.username
}

/// Get the password.
///
/// ```rust
/// # use sqlx_postgres::PgConnectOptions;
/// let options = PgConnectOptions::new()
/// .password("53C237");
/// assert_eq!(options.get_password(), Some("53C237"));
/// ```
pub fn get_password(&self) -> Option<&str> {
self.password.as_deref()
}

/// Get the current database name.
///
/// Defaults to username if not given.
///
/// # Example
///
/// ```rust
/// # use sqlx_postgres::PgConnectOptions;
/// let options = PgConnectOptions::new()
/// .database("postgres");
/// assert!(options.get_database().is_some());
/// let options = PgConnectOptions::new().database("postgres");
/// assert_eq!(options.get_database(), "postgres");
///
/// let options = PgConnectOptions::new().username("alice");
/// assert_eq!(options.get_database(), "alice");
/// ```
pub fn get_database(&self) -> Option<&str> {
self.database.as_deref()
pub fn get_database(&self) -> &str {
self.database.as_deref().unwrap_or(&self.username)
}

/// Get the SSL mode.
Expand Down Expand Up @@ -560,6 +576,19 @@ impl PgConnectOptions {
self.application_name.as_deref()
}

/// Get the extra float digits.
///
/// # Example
///
/// ```rust
/// # use sqlx_postgres::PgConnectOptions;
/// let options = PgConnectOptions::new();
/// assert_eq!(options.get_extra_float_digits(), Some("2"));
/// ```
pub fn get_extra_float_digits(&self) -> std::option::Option<&str> {
self.extra_float_digits.as_deref()
}

/// Get the options.
///
/// # Example
Expand Down
53 changes: 13 additions & 40 deletions sqlx-postgres/src/options/pgpass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,7 @@ use std::io::{BufRead, BufReader};
use std::path::PathBuf;

/// try to load a password from the various pgpass file locations
pub fn load_password(
host: &str,
port: u16,
username: &str,
database: Option<&str>,
) -> Option<String> {
pub fn load_password(host: &str, port: u16, username: &str, database: &str) -> Option<String> {
let custom_file = var_os("PGPASSFILE");
if let Some(file) = custom_file {
if let Some(password) =
Expand Down Expand Up @@ -39,7 +34,7 @@ fn load_password_from_file(
host: &str,
port: u16,
username: &str,
database: Option<&str>,
database: &str,
) -> Option<String> {
let file = File::open(&path)
.map_err(|e| {
Expand Down Expand Up @@ -88,7 +83,7 @@ fn load_password_from_reader(
host: &str,
port: u16,
username: &str,
database: Option<&str>,
database: &str,
) -> Option<String> {
let mut line = String::new();

Expand Down Expand Up @@ -129,7 +124,7 @@ fn load_password_from_line(
host: &str,
port: u16,
username: &str,
database: Option<&str>,
database: &str,
) -> Option<String> {
let whole_line = line;

Expand All @@ -140,7 +135,7 @@ fn load_password_from_line(
_ => {
matches_next_field(whole_line, &mut line, host)?;
matches_next_field(whole_line, &mut line, &port.to_string())?;
matches_next_field(whole_line, &mut line, database.unwrap_or_default())?;
matches_next_field(whole_line, &mut line, database)?;
matches_next_field(whole_line, &mut line, username)?;
Some(line.to_owned())
}
Expand Down Expand Up @@ -268,41 +263,24 @@ mod tests {
"localhost",
5432,
"foo",
Some("bar")
"bar",
),
Some("baz".to_owned())
);
// wildcard
assert_eq!(
load_password_from_line("*:5432:bar:foo:baz", "localhost", 5432, "foo", Some("bar")),
Some("baz".to_owned())
);
// accept wildcard with missing db
assert_eq!(
load_password_from_line("localhost:5432:*:foo:baz", "localhost", 5432, "foo", None),
load_password_from_line("*:5432:bar:foo:baz", "localhost", 5432, "foo", "bar"),
Some("baz".to_owned())
);

// doesn't match
assert_eq!(
load_password_from_line(
"thishost:5432:bar:foo:baz",
"thathost",
5432,
"foo",
Some("bar")
),
load_password_from_line("thishost:5432:bar:foo:baz", "thathost", 5432, "foo", "bar",),
None
);
// malformed entry
assert_eq!(
load_password_from_line(
"localhost:5432:bar:foo",
"localhost",
5432,
"foo",
Some("bar")
),
load_password_from_line("localhost:5432:bar:foo", "localhost", 5432, "foo", "bar",),
None
);
}
Expand All @@ -323,28 +301,23 @@ mod tests {

// normal
assert_eq!(
load_password_from_reader(&mut &file[..], "localhost", 5432, "foo", Some("bar")),
load_password_from_reader(&mut &file[..], "localhost", 5432, "foo", "bar"),
Some("baz".to_owned())
);
// wildcard
assert_eq!(
load_password_from_reader(&mut &file[..], "localhost", 5432, "foo", Some("foobar")),
Some("baz".to_owned())
);
// accept wildcard with missing db
assert_eq!(
load_password_from_reader(&mut &file[..], "localhost", 5432, "foo", None),
load_password_from_reader(&mut &file[..], "localhost", 5432, "foo", "foobar"),
Some("baz".to_owned())
);

// doesn't match
assert_eq!(
load_password_from_reader(&mut &file[..], "thathost", 5432, "foo", Some("foobar")),
load_password_from_reader(&mut &file[..], "thathost", 5432, "foo", "foobar"),
None
);
// malformed entry
assert_eq!(
load_password_from_reader(&mut &file[..], "thathost", 5432, "foo", Some("foobar")),
load_password_from_reader(&mut &file[..], "thathost", 5432, "foo", "foobar"),
None
);
}
Expand Down
8 changes: 4 additions & 4 deletions sqlx-postgres/src/testing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,14 +108,14 @@ async fn test_context(args: &TestArgs) -> Result<TestContext<Postgres>, Error> {
Err((existing, pool)) => {
// Sanity checks.
assert_eq!(
existing.connect_options().host,
pool.connect_options().host,
existing.connect_options().get_host(),
pool.connect_options().get_host(),
"DATABASE_URL changed at runtime, host differs"
);

assert_eq!(
existing.connect_options().database,
pool.connect_options().database,
existing.connect_options().get_database(),
pool.connect_options().get_database(),
"DATABASE_URL changed at runtime, database differs"
);

Expand Down
Loading