Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: rfc7766 #76

Merged
merged 3 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ You should also include the user name that made the change.

## 0.4.3 (Unreleased)

- Feat: add a new feature for serving TCP queries, defined in RFC7766: Pipelining queries and connection reuse.
- Deps.
- Refactor: Various minor improvements.


## 0.4.2

- Feat: Change the default hasher for hashmaps and hashsets from `FxHash` to `aHash` for better performance with string keys. Use `ArcSwap` instead of `RwLock` for internal ODoH config storage.
Expand Down
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ members = ["proxy-bin", "proxy-lib"]
resolver = "2"

[workspace.package]
version = "0.4.2"
version = "0.4.3"
authors = ["Jun Kurihara"]
homepage = "https://github.com/junkurihara/doh-auth-proxy"
repository = "https://github.com/junkurihara/doh-auth-proxy"
Expand Down
2 changes: 2 additions & 0 deletions proxy-lib/src/constants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ pub const UDP_CHANNEL_CAPACITY: usize = 1024; // TODO: channelキャパシティ
pub const UDP_TIMEOUT_SEC: u64 = 10;
/// TCP listen backlog
pub const TCP_LISTEN_BACKLOG: u32 = 1024;
/// TCP idle timeout in secs
pub const TCP_IDLE_TIMEOUT_SEC: u64 = 10;

/// Max connections via UPD and TCP (total) TODO: めちゃ適当
pub const MAX_CONNECTIONS: usize = 128;
Expand Down
3 changes: 3 additions & 0 deletions proxy-lib/src/globals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ pub struct ProxyConfig {
pub udp_timeout_sec: Duration,
/// TCP listen backlog
pub tcp_listen_backlog: u32,
/// TCP idle timeout
pub tcp_idle_timeout_sec: Duration,

/// timeout for HTTP requests (DoH, ODoH, and authentication requests)
pub http_timeout_sec: Duration,
Expand Down Expand Up @@ -197,6 +199,7 @@ impl Default for ProxyConfig {
udp_channel_capacity: UDP_CHANNEL_CAPACITY,
udp_timeout_sec: Duration::from_secs(UDP_TIMEOUT_SEC),
tcp_listen_backlog: TCP_LISTEN_BACKLOG,
tcp_idle_timeout_sec: Duration::from_secs(TCP_IDLE_TIMEOUT_SEC),

http_timeout_sec: Duration::from_secs(HTTP_TIMEOUT_SEC),
http_user_agent: format!("{}/{}", HTTP_USER_AGENT, env!("CARGO_PKG_VERSION")),
Expand Down
2 changes: 1 addition & 1 deletion proxy-lib/src/http_client/trait_resolve_ips.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub async fn resolve_ips(endpoints: &[Url], resolver_ips: impl ResolveIps) -> Re
let resolve_ips_fut = endpoints.iter().map(|endpoint| async {
let host_is_ipaddr = endpoint
.host_str()
.map_or(false, |host| host.parse::<std::net::IpAddr>().is_ok());
.is_some_and(|host| host.parse::<std::net::IpAddr>().is_ok());
if host_is_ipaddr {
Ok(ResolveIpResponse {
hostname: endpoint.host_str().unwrap().to_string(),
Expand Down
127 changes: 96 additions & 31 deletions proxy-lib/src/proxy/proxy_tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ impl Proxy {
}

/// Serve TCP query
pub async fn serve_tcp_query(self, mut stream: TcpStream, src_addr: SocketAddr) -> Result<()> {
pub async fn serve_tcp_query(self, stream: TcpStream, src_addr: SocketAddr) -> Result<()> {
debug!("handle tcp query from {:?}", src_addr);
let counter = self.counter.clone();
if counter.increment(CounterType::Tcp) >= self.globals.proxy_config.max_connections as isize {
Expand All @@ -48,43 +48,108 @@ impl Proxy {
counter.decrement(CounterType::Tcp);
return Err(Error::TooManyConnections);
}
// let doh_client = self.context.get_random_client().await?;

// read data from stream
// first 2bytes indicates the length of dns message following from the 3rd byte
let mut length_buf = [0u8; 2];
stream.read_exact(&mut length_buf).await?;
let msg_length = u16::from_be_bytes(length_buf) as usize;
if msg_length == 0 {
return Err(Error::NullTcpStream);
}
let mut packet_buf = vec![0u8; msg_length];
stream.read_exact(&mut packet_buf).await?;
let res = self.serve_tcp_query_inner(stream, src_addr).await;

// decrement counter anyways
counter.decrement(CounterType::Tcp);

res
}

// make DoH query
let res = tokio::time::timeout(
self.globals.proxy_config.http_timeout_sec + std::time::Duration::from_secs(1),
// serve tcp dns message here
self.doh_client.make_doh_query(&packet_buf, ProxyProtocol::Tcp, &src_addr),
)
.await
.ok();
// debug!("response from DoH server: {:?}", res);
/// Serve TCP query inner, supporting connection reuse and pipelining (RFC7766)
pub async fn serve_tcp_query_inner(self, stream: TcpStream, src_addr: SocketAddr) -> Result<()> {
// split stream into readable and writeable
let (mut readable_stream, mut writeable_stream) = stream.into_split();

// send response via stream
counter.decrement(CounterType::Tcp); // decrement counter anyways
/* ------- */
// spawn a task to make doh query and write response to stream
let (tx, mut rx) = tokio::sync::mpsc::channel::<Vec<u8>>(10);
let self_clone = self.clone();
self.globals.runtime_handle.spawn(async move {
while let Some(q) = rx.recv().await {
// if query is empty, connection is closed
if q.is_empty() {
debug!("TCP connection closed");
break;
}

if let Some(Ok(r)) = res {
if r.len() > (u16::MAX as usize) {
return Err(Error::InvalidDnsResponseSize);
// make DoH query
let res = tokio::time::timeout(
self_clone.globals.proxy_config.http_timeout_sec + std::time::Duration::from_secs(1),
// serve tcp dns message here
self_clone.doh_client.make_doh_query(&q, ProxyProtocol::Tcp, &src_addr),
)
.await
.ok();

// send response via stream
if let Some(Ok(r)) = res {
if r.len() > (u16::MAX as usize) {
error!("Response too large: {}", r.len());
break;
}
let length_buf = u16::to_be_bytes(r.len() as u16);
if let Err(e) = writeable_stream.write_all(&length_buf).await {
error!("Failed to write length to stream: {e}");
break;
}
if let Err(e) = writeable_stream.write_all(&r).await {
error!("Failed to write response to stream: {e}");
break;
}
} else {
error!("Failed to make DoH query");
break;
}
}
debug!("Finish serving TCP writable stream");
});
/* ------- */

// read query from readable stream
loop {
let Ok(res) = tokio::time::timeout(
self.globals.proxy_config.tcp_idle_timeout_sec,
read_query(&mut readable_stream),
)
.await
else {
debug!("TCP idle timeout or TCP connection closed");
let _ = tx.send(vec![]).await; // send empty vec to close connection
break;
};
let packet_buf = res?;
let qsize = packet_buf.len();
let _ = tx.send(packet_buf).await;
if qsize == 0 {
// connection closed
break;
}
let length_buf = u16::to_be_bytes(r.len() as u16);
stream.write_all(&length_buf).await?;
stream.write_all(&r).await?;
} else {
return Err(Error::FailedToMakeDohQuery);
}

Ok(())
}
}

/// Read query from stream, if stream is closed, return empty vec
async fn read_query(stream: &mut tokio::net::tcp::OwnedReadHalf) -> Result<Vec<u8>> {
// check if stream is closed
let x = stream.peek(&mut [0u8]).await?;
if x == 0 {
debug!("TCP connection closed");
return Ok(vec![]);
}

// read data from stream
// first 2bytes indicates the length of dns message following from the 3rd byte
let mut length_buf = [0u8; 2];
stream.read_exact(&mut length_buf).await?;
let msg_length = u16::from_be_bytes(length_buf) as usize;
if msg_length == 0 {
return Err(Error::NullTcpStream);
}
let mut packet_buf = vec![0u8; msg_length];
stream.read_exact(&mut packet_buf).await?;
Ok(packet_buf)
}
Loading