
fix: single-connection mode connection reestablishment #39

Merged
8 commits merged on Sep 10, 2024
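
In single-connection mode the server can close the long-lived connection underneath the client (for example when it restarts), after which a write can still appear to succeed while the next read hangs. The change below probes the connection before each send and redials if the peer has gone away. A minimal, self-contained sketch of that flow; the `still_open` / `send_reconnecting` names and the free-function shape are illustrative, not this crate's API:

```rust
use std::io;
use std::task::Poll;

use futures::{poll, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

/// One-shot liveness probe: poll a one-byte read exactly once.
/// EOF or a reset/broken-pipe error means the peer closed the connection;
/// `Pending` (or buffered data) means it is still open.
async fn still_open<C: AsyncRead + Unpin>(conn: &mut C) -> io::Result<bool> {
    let mut buf = [0u8];
    match poll!(conn.read(&mut buf)) {
        Poll::Ready(Ok(0)) => Ok(false),
        Poll::Ready(Err(e))
            if matches!(
                e.kind(),
                io::ErrorKind::BrokenPipe | io::ErrorKind::ConnectionReset
            ) =>
        {
            Ok(false)
        }
        Poll::Ready(Err(e)) => Err(e),
        _ => Ok(true),
    }
}

/// Send `bytes`, dialing a fresh connection first if the old one is gone.
/// `connect` stands in for the client's connection factory closure.
async fn send_reconnecting<C, F, Fut>(
    slot: &mut Option<C>,
    connect: F,
    bytes: &[u8],
) -> io::Result<()>
where
    C: AsyncRead + AsyncWrite + Unpin,
    F: Fn() -> Fut,
    Fut: std::future::Future<Output = io::Result<C>>,
{
    let need_new = match slot.as_mut() {
        Some(conn) => !still_open(conn).await?,
        None => true,
    };
    if need_new {
        // drop the stale connection (if any) and redial before sending
        *slot = Some(connect().await?);
    }
    slot.as_mut().unwrap().write_all(bytes).await
}
```
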
8 changes: 7 additions & 1 deletion tacacs-plus/Cargo.toml
@@ -16,7 +16,13 @@ md-5 = "0.10.6"
uuid = { version = "1.10.0", features = ["v4"] }

[dev-dependencies]
tokio = { version = "1.39.1", features = ["rt", "net", "time", "macros"] }
tokio = { version = "1.39.1", features = [
"rt",
"net",
"time",
"macros",
"process",
] }
tokio-util = { version = "0.7.11", features = ["compat"] }
async-net = "2.0.0"
async-std = { version = "1.12.0", features = ["attributes"] }
163 changes: 161 additions & 2 deletions tacacs-plus/src/inner.rs
@@ -4,8 +4,10 @@ use std::fmt;
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::task::Poll;

use byteorder::{ByteOrder, NetworkEndian};
use futures::poll;
use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tacacs_plus_protocol::{Deserialize, PacketBody, Serialize};
use tacacs_plus_protocol::{HeaderInfo, Packet, PacketFlags};
@@ -119,11 +121,28 @@ impl<S: AsyncRead + AsyncWrite + Unpin> ClientInner<S> {
Ok(conn)
}

/// Writes a packet to the underlying connection.
/// Writes a packet to the underlying connection, reconnecting if necessary.
pub(super) async fn send_packet<B: PacketBody + Serialize>(
&mut self,
packet: Packet<B>,
secret_key: Option<&[u8]>,
) -> Result<(), ClientError> {
// check if other end closed our connection, and reopen it accordingly
let connection = self.connection().await?;
if !is_connection_open(connection).await? {
self.post_session_cleanup(true).await?;
}

// send the packet after ensuring the connection is valid (or dropping
// it if it's invalid)
self._send_packet(packet, secret_key).await
}

/// Writes a packet to the underlying connection.
async fn _send_packet<B: PacketBody + Serialize>(
&mut self,
packet: Packet<B>,
secret_key: Option<&[u8]>,
) -> Result<(), ClientError> {
// allocate zero-filled buffer large enough to hold packet
let mut packet_buffer = vec![0; packet.wire_size()];
@@ -195,7 +214,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> ClientInner<S> {
pub(super) async fn post_session_cleanup(&mut self, status_is_error: bool) -> io::Result<()> {
// close session if server doesn't agree to SINGLE_CONNECTION negotiation, or if an error occurred (since a mutex guarantees only one session is going at a time)
if !self.single_connection_established || status_is_error {
// SAFETY: ensure_connection should be called before this function, and guarantees inner.connection is non-None
// SAFETY: connection() should be called before this function, and guarantees inner.connection is non-None
let mut connection = self.connection.take().unwrap();
connection.close().await?;

@@ -212,3 +231,143 @@
Ok(())
}
}

/// Checks if the provided connection is still open on both sides.
///
/// This is accomplished by attempting to read a single byte from the connection
/// and checking for an EOF condition or specific errors (broken pipe/connection reset).
///
/// This might be overkill, but during testing I encountered a case where a write succeeded
/// and a subsequent read hung due to the connection being closed on the other side, so
/// avoiding that is preferable.
async fn is_connection_open<C>(connection: &mut C) -> io::Result<bool>
where
C: AsyncRead + Unpin,
{
let mut buffer = [0];

// poll the read future exactly once to see if anything is ready immediately
match poll!(connection.read(&mut buffer)) {
// something ready on first poll likely indicates something wrong, since we aren't
// expecting any data to actually be ready
Poll::Ready(ready) => match ready {
// read of length 0 indicates an EOF, which happens when the other side closes a TCP connection
Ok(0) => Ok(false),

Err(e) => match e.kind() {
// these errors indicate that the connection is closed, which is the exact
// situation we're trying to recover from
//
// BrokenPipe seems to be Linux-specific (?), ConnectionReset is more general though
// (checked TCP & read(2) man pages for MacOS/FreeBSD/Linux)
io::ErrorKind::BrokenPipe | io::ErrorKind::ConnectionReset => Ok(false),

// bubble up any other errors to the caller
_ => Err(e),
},

// if there's data still available, the connection is still open, although
// this shouldn't happen in the context of TACACS+
Contributor: nit: should there be a debug/trace log here? Would this be possible to hit in a multi-threaded context? Or would it hit the mutex lock?

Contributor Author: It should be prevented by the mutex since it's locked for the full session within Client.
(A sketch of that session-level locking follows this file's diff.)

Ok(1..) => Ok(true),
},

// nothing ready to read -> connection is still open
Poll::Pending => Ok(true),
}
}

#[cfg(test)]
mod tests {
use std::sync::Arc;

use futures::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::Notify;
use tokio_util::compat::TokioAsyncReadCompatExt;

use super::is_connection_open;

fn get_test_address() -> String {
std::env::var("TCP_TEST_ADDR").unwrap_or(String::from("localhost:9999"))
}

async fn bind_to_test_address() -> TcpListener {
let address = get_test_address();

TcpListener::bind(&address)
.await
.unwrap_or_else(|err| panic!("failed to bind to address {address}: {err:?}"))
}

#[tokio::test]
async fn connection_open_check() {
let notify = Arc::new(Notify::new());
let listener_notify = notify.clone();

tokio::spawn(async move {
let listener = bind_to_test_address().await;
listener_notify.notify_one();

let (stream, _) = listener
.accept()
.await
.expect("failed to accept connection");

let mut stream = stream.compat();
let mut buf = [0];
stream.read(&mut buf).await
});

// wait for server to bind to address
notify.notified().await;
Contributor: these are nice to coordinate between threads (especially in tests)

let client = TcpStream::connect(get_test_address())
.await
.expect("couldn't connect to test listener");
let mut client = client.compat();

let is_open = is_connection_open(&mut client)
.await
.expect("couldn't check if connection was open");
assert!(is_open);
}

#[tokio::test]
async fn connection_closed_check() {
let notify = Arc::new(Notify::new());
let listener_notify = notify.clone();

tokio::spawn(async move {
let listener = bind_to_test_address().await;
listener_notify.notify_one();

let (stream, _) = listener
.accept()
.await
.expect("failed to accept connection");

let mut stream = stream.compat();

// close connection & notify main test task
stream.close().await.unwrap();
listener_notify.notify_one();
});

// wait for server to bind to address
notify.notified().await;

let client = TcpStream::connect(get_test_address())
.await
.expect("couldn't connect to test listener");
let mut client = client.compat();

// let server close connection
notify.notified().await;

// ensure connection is detected as closed
let is_open = is_connection_open(&mut client)
.await
.expect("couldn't check if connection was open");
assert!(!is_open);
}
}
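
On the review question above about the `Ok(1..)` arm being reachable from multiple threads: a rough sketch of the session-level locking described in the reply. The `Inner`/`Client` names here are illustrative stand-ins, not the crate's actual types; the point is only that the async mutex guard is held across the whole packet exchange, so a second task cannot touch the connection mid-session.

```rust
use std::sync::Arc;

use futures::lock::Mutex; // async mutex, so the guard can be held across awaits

// Illustrative stand-ins for the real client internals.
struct Inner {
    // connection, sequence numbers, single-connection flag, ...
}

struct Client {
    inner: Arc<Mutex<Inner>>,
}

impl Client {
    async fn run_session(&self) {
        // Held for the entire session: no other task can read from or write
        // to the underlying connection until this guard is dropped.
        let _inner = self.inner.lock().await;

        // send packet(s), read reply(ies); the lock is released on drop
    }
}
```
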
60 changes: 55 additions & 5 deletions tacacs-plus/tests/pap_login.rs
@@ -1,12 +1,16 @@
use std::time::Duration;

use futures::{FutureExt, TryFutureExt};
use tokio::net::TcpStream;
use tokio_util::compat::TokioAsyncWriteCompatExt;
use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};

use tacacs_plus::Client;
use tacacs_plus::Client as TacacsClient;
use tacacs_plus::{AuthenticationType, ContextBuilder, ResponseStatus};

mod common;

type Client = TacacsClient<Compat<TcpStream>>;

#[tokio::test]
async fn pap_success() {
// NOTE: this assumes you have a TACACS+ server running already
@@ -24,10 +28,43 @@ async fn pap_success() {
Some(common::SECRET_KEY),
);

let context = ContextBuilder::new("someuser".to_owned()).build();
attempt_pap_login(&tac_client, "someuser".to_owned(), "hunter2").await;
}

let response = tac_client
.authenticate(context, "hunter2", AuthenticationType::Pap)
// this test is ignored since it takes a bit to run & requires specific actions to run alongside the test (restarting server)
#[tokio::test]
#[ignore]
async fn connection_reestablishment() {
let address = common::get_server_address();
let client = Client::new(
Box::new(move || {
TcpStream::connect(address.clone())
.map_ok(TokioAsyncWriteCompatExt::compat_write)
.boxed()
}),
Some(common::SECRET_KEY),
);

let user = String::from("paponly");
let password = "pass-word";
attempt_pap_login(&client, user.clone(), password).await;

// restart server container
if let Ok(container_name) = std::env::var("SERVER_CONTAINER") {
restart_server_container(container_name).await;
}

// sleep for a bit to allow server time to start back up
tokio::time::sleep(Duration::from_millis(500)).await;

// try logging in after server restart to ensure connection is reestablished
attempt_pap_login(&client, user, password).await;
}

async fn attempt_pap_login(client: &Client, user: String, password: &str) {
let context = ContextBuilder::new(user).build();
let response = client
.authenticate(context, password, AuthenticationType::Pap)
.await
.expect("error completing authentication session");

@@ -37,3 +74,16 @@
"authentication failed, full response: {response:?}"
);
}

async fn restart_server_container(name: String) {
let docker_command = std::env::var("docker").unwrap_or_else(|_| "docker".to_owned());

let status = tokio::process::Command::new(docker_command)
.args(["restart", &name])
.stdout(std::process::Stdio::null())
.status()
.await
.expect("couldn't get exit status of server container restart command");

assert!(status.success(), "bad exit status: {status}");
}
1 change: 1 addition & 0 deletions test-assets/Dockerfile
@@ -114,6 +114,7 @@ id = tac_plus-ng {
device everything {
key = "very secure key that is super secret"
address = 0.0.0.0/0
single-connection = yes

script { rewrite user = emptyGuest }
}
12 changes: 8 additions & 4 deletions test-assets/run-client-tests.sh
@@ -3,7 +3,8 @@ set -euo pipefail

REPO_ROOT=$(git rev-parse --show-toplevel)
TMPDIR=$(mktemp -d)
docker=${docker:-docker}
export docker=${docker:-docker}
export SERVER_CONTAINER=tacacs-server

if [ ! -v CI ]; then
# build server image
@@ -30,21 +31,24 @@ test_against_server_image() {
echo "Testing against image: $image"

echo "Running server container in background"
$docker run --rm --detach --publish 5555:5555 --name tacacs-server "$image" >/dev/null
$docker run --rm --detach --publish 5555:5555 --name $SERVER_CONTAINER "$image" >/dev/null

# run all integration tests against server
# run integration tests against server
echo "Running tests..."
cargo test --package tacacs-plus --test '*' --no-fail-fast

# copy accounting file out of container
$docker cp tacacs-server:/tmp/accounting.log $TMPDIR/accounting.log
$docker cp $SERVER_CONTAINER:/tmp/accounting.log $TMPDIR/accounting.log

# verify contents of accounting file, printing if invalid
if ! $REPO_ROOT/test-assets/validate_accounting_file.py $TMPDIR/accounting.log; then
echo 'accounting file:'
cat $TMPDIR/accounting.log
return 1
fi

# test reconnection by restarting server mid test-run
cargo test --package tacacs-plus --test pap_login connection_reestablishment -- --ignored
}

trap "stop_running_containers" EXIT