Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unix domain socket support for Postgres #253

Merged
merged 6 commits into from
Apr 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sqlx-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ all-type = ["bigdecimal", "json", "time", "chrono", "ipnetwork", "uuid"]
# we need a feature which activates `num-bigint` as well because
# `bigdecimal` uses types from it but does not reexport (tsk tsk)
bigdecimal = ["bigdecimal_", "num-bigint"]
postgres = [ "md-5", "sha2", "base64", "sha-1", "rand", "hmac", "futures-channel/sink", "futures-util/sink" ]
postgres = [ "md-5", "sha2", "base64", "sha-1", "rand", "hmac", "futures-channel/sink", "futures-util/sink", "tokio/uds" ]
json = ["serde", "serde_json"]
mysql = [ "sha-1", "sha2", "generic-array", "num-bigint", "base64", "digest", "rand" ]
sqlite = [ "libsqlite3-sys" ]
Expand Down
28 changes: 23 additions & 5 deletions sqlx-core/src/io/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use std::pin::Pin;
use std::task::{Context, Poll};

use crate::runtime::{AsyncRead, AsyncWrite, TcpStream};
use crate::url::Url;

use self::Inner::*;

Expand All @@ -14,15 +13,24 @@ pub struct MaybeTlsStream {

enum Inner {
NotTls(TcpStream),
#[cfg(all(feature = "postgres", unix))]
UnixStream(crate::runtime::UnixStream),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't like how this in MaybeTlsStream. I think it conflates what the type here is doing.

I'd rather something like:

  • A separate Socket (name?) type that is internally an enumeration between UnixStream and TcpStream.

  • This MaybeTlsStream is then generalized around an S type parameter.

Thoughts?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see much reason to parameterize MaybeTlsStream if the type is always going to be Socket.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could call it MaybeUdsStream and reverse the relationship so that MaybeUdsStream is an enumeration between UnixStream and MaybeTlsStream.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess before we get ahead of ourselves here.. is TLS over UDS even a reasonable thing? If it is, we probably need to at least support that with the type sandwhich.

Copy link
Collaborator

@abonander abonander Apr 16, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The wording in the documentation doesn't suggest TLS support at all over the domain socket: https://www.postgresql.org/docs/12/ssl-tcp.html

It doesn't really make sense anyways to encrypt over the UDS because if someone can sniff the traffic on that socket then they can probably also just inspect your process' memory or Postgres' memory/files for juicy secrets (since they'd have to be running on the same machine with elevated privileges already).

#[cfg(feature = "tls")]
Tls(async_native_tls::TlsStream<TcpStream>),
#[cfg(feature = "tls")]
Upgrading,
}

impl MaybeTlsStream {
pub async fn connect(url: &Url, default_port: u16) -> crate::Result<Self> {
let conn = TcpStream::connect((url.host(), url.port(default_port))).await?;
#[cfg(all(feature = "postgres", unix))]
pub async fn connect_uds<S: AsRef<std::ffi::OsStr>>(p: S) -> crate::Result<Self> {
let conn = crate::runtime::UnixStream::connect(p.as_ref()).await?;
Ok(Self {
inner: Inner::UnixStream(conn),
})
}
pub async fn connect(host: &str, port: u16) -> crate::Result<Self> {
let conn = TcpStream::connect((host, port)).await?;
Ok(Self {
inner: Inner::NotTls(conn),
})
Expand All @@ -32,6 +40,8 @@ impl MaybeTlsStream {
pub fn is_tls(&self) -> bool {
match self.inner {
Inner::NotTls(_) => false,
#[cfg(all(feature = "postgres", unix))]
Inner::UnixStream(_) => false,
#[cfg(feature = "tls")]
Inner::Tls(_) => true,
#[cfg(feature = "tls")]
Expand All @@ -43,23 +53,29 @@ impl MaybeTlsStream {
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
pub async fn upgrade(
&mut self,
url: &Url,
host: &str,
connector: async_native_tls::TlsConnector,
) -> crate::Result<()> {
let conn = match std::mem::replace(&mut self.inner, Upgrading) {
NotTls(conn) => conn,
#[cfg(all(feature = "postgres", unix))]
UnixStream(_) => {
return Err(tls_err!("TLS is not supported with unix domain sockets").into())
}
Tls(_) => return Err(tls_err!("connection already upgraded").into()),
Upgrading => return Err(tls_err!("connection already failed to upgrade").into()),
};

self.inner = Tls(connector.connect(url.host(), conn).await?);
self.inner = Tls(connector.connect(host, conn).await?);

Ok(())
}

pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
match self.inner {
NotTls(ref conn) => conn.shutdown(how),
#[cfg(all(feature = "postgres", unix))]
UnixStream(ref conn) => conn.shutdown(how),
#[cfg(feature = "tls")]
Tls(ref conn) => conn.get_ref().shutdown(how),
#[cfg(feature = "tls")]
Expand All @@ -73,6 +89,8 @@ macro_rules! forward_pin (
($self:ident.$method:ident($($arg:ident),*)) => (
match &mut $self.inner {
NotTls(ref mut conn) => Pin::new(conn).$method($($arg),*),
#[cfg(all(feature = "postgres", unix))]
UnixStream(ref mut conn) => Pin::new(conn).$method($($arg),*),
#[cfg(feature = "tls")]
Tls(ref mut conn) => Pin::new(conn).$method($($arg),*),
#[cfg(feature = "tls")]
Expand Down
4 changes: 3 additions & 1 deletion sqlx-core/src/mysql/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ pub(crate) struct MySqlStream {

impl MySqlStream {
pub(super) async fn new(url: &Url) -> crate::Result<Self> {
let stream = MaybeTlsStream::connect(&url, 3306).await?;
let host = url.host().unwrap_or("localhost");
let port = url.port(3306);
let stream = MaybeTlsStream::connect(host, port).await?;

let mut capabilities = Capabilities::PROTOCOL_41
| Capabilities::IGNORE_SPACE
Expand Down
5 changes: 4 additions & 1 deletion sqlx-core/src/mysql/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,5 +116,8 @@ async fn try_upgrade(
)
.await?;

stream.stream.upgrade(url, connector).await
stream
.stream
.upgrade(url.host().unwrap_or("localhost"), connector)
.await
}
10 changes: 7 additions & 3 deletions sqlx-core/src/postgres/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,13 @@ pub struct PgConnection {

// https://www.postgresql.org/docs/12/protocol-flow.html#id-1.10.5.7.3
async fn startup(stream: &mut PgStream, url: &Url) -> crate::Result<BackendKeyData> {
// Defaults to postgres@.../postgres
let username = url.username().unwrap_or(Cow::Borrowed("postgres"));
let database = url.database().unwrap_or("postgres");
// Defaults to $USER@.../$USER
// and falls back to postgres@.../postgres
let username = url
.username()
.or_else(|| std::env::var("USER").map(Cow::Owned).ok())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps https://crates.io/crates/whoami is a better idea to be more resilient?

.unwrap_or(Cow::Borrowed("postgres"));
let database = url.database().unwrap_or(&username);

// See this doc for more runtime parameters
// https://www.postgresql.org/docs/12/runtime-config-client.html
Expand Down
22 changes: 21 additions & 1 deletion sqlx-core/src/postgres/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,27 @@ pub struct PgStream {

impl PgStream {
pub(super) async fn new(url: &Url) -> crate::Result<Self> {
let stream = MaybeTlsStream::connect(&url, 5432).await?;
let host = url.host();
let port = url.port(5432);
#[cfg(unix)]
let stream = {
let host = host
.map(|host| {
percent_encoding::percent_decode_str(host)
.decode_utf8()
.expect("percent-encoded hostname contained non-UTF-8 bytes")
})
.or_else(|| url.param("host"))
.unwrap_or("/var/run/postgresql".into());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By default we may want to iterate potential locations for the socket and not just hard-code this one.

  • /tmp
  • /var/run/postgresql
  • /private/tmp ( macOS )

if host.starts_with("/") {
let path = format!("{}/.s.PGSQL.{}", host, port);
MaybeTlsStream::connect_uds(&path).await?
} else {
MaybeTlsStream::connect(&host, port).await?
}
};
#[cfg(not(unix))]
let stream = MaybeTlsStream::connect(host.unwrap_or("localhost"), port).await?;

Ok(Self {
notifications: None,
Expand Down
8 changes: 7 additions & 1 deletion sqlx-core/src/postgres/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ async fn try_upgrade(
accept_invalid_host_names: bool,
) -> crate::Result<bool> {
use async_native_tls::TlsConnector;
use std::borrow::Cow;

stream.write(crate::postgres::protocol::SslRequest);
stream.flush().await?;
Expand Down Expand Up @@ -105,7 +106,12 @@ async fn try_upgrade(
}
}

stream.stream.upgrade(url, connector).await?;
let host = url
.host()
.map(Cow::Borrowed)
.or_else(|| url.param("host"))
.unwrap_or("localhost".into());
stream.stream.upgrade(&host, connector).await?;

Ok(true)
}
Expand Down
6 changes: 6 additions & 0 deletions sqlx-core/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ pub(crate) use async_std::{
task::spawn,
};

#[cfg(all(feature = "runtime-async-std", feature = "postgres", unix))]
pub(crate) use async_std::os::unix::net::UnixStream;

#[cfg(feature = "runtime-tokio")]
pub(crate) use tokio::{
fs,
Expand All @@ -26,3 +29,6 @@ pub(crate) use tokio::{
time::delay_for as sleep,
time::timeout,
};

#[cfg(all(feature = "runtime-tokio", feature = "postgres", unix))]
pub(crate) use tokio::net::UnixStream;
11 changes: 4 additions & 7 deletions sqlx-core/src/url.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,10 @@ impl Url {
self.0.as_str()
}

pub fn host(&self) -> &str {
let host = self.0.host_str();

match host {
Some(host) if !host.is_empty() => host,

_ => "localhost",
pub fn host(&self) -> Option<&str> {
match self.0.host_str()? {
"" => None,
host => Some(host),
}
}

Expand Down