Skip to content

Commit

Permalink
fix(dgw): enforce recording policy (#906)
Browse files Browse the repository at this point in the history
When the recording flag is set and the recording stream is closed, the associated
session is killed within 10 seconds.

Issue: DGW-86
  • Loading branch information
CBenoit authored Jul 1, 2024
1 parent 9adaa8d commit 13ed397
Show file tree
Hide file tree
Showing 5 changed files with 194 additions and 20 deletions.
3 changes: 0 additions & 3 deletions devolutions-gateway/src/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,6 @@ where
)
.await?;

// NOTE(DGW-86): when recording is required, should we wait for it to start before we forward, or simply spawn
// a timer to check if the recording is started within a few seconds?

let kill_notified = notify_kill.notified();

let res = if let Some(buffer_size) = self.buffer_size {
Expand Down
61 changes: 58 additions & 3 deletions devolutions-gateway/src/recording.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use tokio::{fs, io};
use typed_builder::TypedBuilder;
use uuid::Uuid;

use crate::session::SessionMessageSender;
use crate::token::{JrecTokenClaims, RecordingFileType};

const DISCONNECTED_TTL_SECS: i64 = 10;
Expand Down Expand Up @@ -162,6 +163,7 @@ struct OnGoingRecording {
state: OnGoingRecordingState,
manifest: JrecManifest,
manifest_path: Utf8PathBuf,
session_must_be_recorded: bool,
}

enum RecordingManagerMessage {
Expand Down Expand Up @@ -310,14 +312,20 @@ pub struct RecordingManagerTask {
rx: RecordingMessageReceiver,
ongoing_recordings: HashMap<Uuid, OnGoingRecording>,
recordings_path: Utf8PathBuf,
session_manager_handle: SessionMessageSender,
}

impl RecordingManagerTask {
pub fn new(rx: RecordingMessageReceiver, recordings_path: Utf8PathBuf) -> Self {
/// Builds the recording manager task.
///
/// * `rx` — receiver side of the recording message channel.
/// * `recordings_path` — root directory under which recording files are stored.
/// * `session_manager_handle` — used to query session info (recording policy flag)
///   and to kill sessions that violate the policy.
pub fn new(
    rx: RecordingMessageReceiver,
    recordings_path: Utf8PathBuf,
    session_manager_handle: SessionMessageSender,
) -> Self {
    Self {
        rx,
        // No recording is ongoing until a start message is received.
        ongoing_recordings: HashMap::new(),
        recordings_path,
        session_manager_handle,
    }
}

Expand Down Expand Up @@ -389,12 +397,26 @@ impl RecordingManagerTask {

let active_recording_count = self.rx.active_recordings.insert(id);

// NOTE: the session associated to this recording is not always running through the Devolutions Gateway.
// It is a normal situation when the Devolutions Gateway is used solely as a recording server.
// In such cases, we can only assume there is no recording policy.
let session_must_be_recorded = self
.session_manager_handle
.get_session_info(id)
.await
.inspect_err(|error| error!(%error, session.id = %id, "Failed to retrieve session info"))
.ok()
.flatten()
.map(|info| info.recording_policy)
.unwrap_or(false);

self.ongoing_recordings.insert(
id,
OnGoingRecording {
state: OnGoingRecordingState::Connected,
manifest,
manifest_path,
session_must_be_recorded,
},
);
let ongoing_recording_count = self.ongoing_recordings.len();
Expand Down Expand Up @@ -453,9 +475,42 @@ impl RecordingManagerTask {
OnGoingRecordingState::LastSeen { timestamp } if now >= timestamp + DISCONNECTED_TTL_SECS - 1 => {
debug!(%id, "Mark recording as terminated");
self.rx.active_recordings.remove(id);
self.ongoing_recordings.remove(&id);

// TODO(DGW-86): now is a good timing to kill sessions that _must_ be recorded
// Check the recording policy of the associated session and kill it if necessary.
if ongoing.session_must_be_recorded {
tokio::spawn({
let session_manager_handle = self.session_manager_handle.clone();

async move {
let result = session_manager_handle.kill_session(id).await;

match result {
Ok(crate::session::KillResult::Success) => {
warn!(
session.id = %id,
reason = "recording policy violated",
"Session killed",
);
}
Ok(crate::session::KillResult::NotFound) => {
trace!(
session.id = %id,
"Associated session is not running, as expected",
);
}
Err(error) => {
error!(
session.id = %id,
%error,
"Couldn’t kill session",
)
}
}
}
});
}

self.ongoing_recordings.remove(&id);
}
_ => {
trace!(%id, "Recording should not be removed yet");
Expand Down
7 changes: 5 additions & 2 deletions devolutions-gateway/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ async fn spawn_tasks(conf_handle: ConfHandle) -> anyhow::Result<Tasks> {
sessions: session_manager_handle.clone(),
subscriber_tx: subscriber_tx.clone(),
shutdown_signal: tasks.shutdown_signal.clone(),
recordings: recording_manager_handle,
recordings: recording_manager_handle.clone(),
};

conf.listeners
Expand Down Expand Up @@ -243,7 +243,7 @@ async fn spawn_tasks(conf_handle: ConfHandle) -> anyhow::Result<Tasks> {
));

tasks.register(devolutions_gateway::subscriber::SubscriberPollingTask {
sessions: session_manager_handle,
sessions: session_manager_handle.clone(),
subscriber: subscriber_tx,
});

Expand All @@ -253,12 +253,15 @@ async fn spawn_tasks(conf_handle: ConfHandle) -> anyhow::Result<Tasks> {
});

tasks.register(devolutions_gateway::session::SessionManagerTask::new(
session_manager_handle.clone(),
session_manager_rx,
recording_manager_handle,
));

tasks.register(devolutions_gateway::recording::RecordingManagerTask::new(
recording_manager_rx,
conf.recording_path.clone(),
session_manager_handle,
));

Ok(tasks)
Expand Down
128 changes: 117 additions & 11 deletions devolutions-gateway/src/session.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::recording::RecordingMessageSender;
use crate::subscriber;
use crate::target_addr::TargetAddr;
use crate::token::{ApplicationProtocol, SessionTtl};
Expand Down Expand Up @@ -133,6 +134,10 @@ enum SessionManagerMessage {
info: SessionInfo,
notify_kill: Arc<Notify>,
},
GetInfo {
id: Uuid,
channel: oneshot::Sender<Option<SessionInfo>>,
},
Remove {
id: Uuid,
channel: oneshot::Sender<Option<SessionInfo>>,
Expand All @@ -155,6 +160,9 @@ impl fmt::Debug for SessionManagerMessage {
SessionManagerMessage::New { info, notify_kill: _ } => {
f.debug_struct("New").field("info", info).finish_non_exhaustive()
}
SessionManagerMessage::GetInfo { id, channel: _ } => {
f.debug_struct("GetInfo").field("id", id).finish_non_exhaustive()
}
SessionManagerMessage::Remove { id, channel: _ } => {
f.debug_struct("Remove").field("id", id).finish_non_exhaustive()
}
Expand All @@ -179,6 +187,16 @@ impl SessionMessageSender {
.context("couldn't send New message")
}

/// Retrieves the [`SessionInfo`] associated with `id` from the session manager task.
///
/// Returns `Ok(None)` when no session with this ID is currently running.
///
/// # Errors
///
/// Fails if the session manager task is no longer running (send channel closed)
/// or if it drops the response channel without answering.
pub async fn get_session_info(&self, id: Uuid) -> anyhow::Result<Option<SessionInfo>> {
    let (tx, rx) = oneshot::channel();
    self.0
        .send(SessionManagerMessage::GetInfo { id, channel: tx })
        .await
        .ok()
        // Fixed: previously said "Remove message" (copy-paste from `remove_session`).
        .context("couldn't send GetInfo message")?;
    rx.await.context("couldn't receive info for session")
}

pub async fn remove_session(&self, id: Uuid) -> anyhow::Result<Option<SessionInfo>> {
let (tx, rx) = oneshot::channel();
self.0
Expand Down Expand Up @@ -256,26 +274,48 @@ impl Ord for WithTtlInfo {
}

/// Long-running task owning the state of all sessions tracked by this Gateway.
pub struct SessionManagerTask {
    // Sender half kept so the task can hand out handles (see `handle()`) and so
    // sub-tasks it spawns (e.g., recording-policy enforcement) can message it back.
    tx: SessionMessageSender,
    rx: SessionMessageReceiver,
    // All currently running sessions, keyed by session (association) ID.
    all_running: RunningSessions,
    // One kill notifier per running session; notified to request termination.
    all_notify_kill: HashMap<Uuid, Arc<Notify>>,
    // Handle used to query the recording manager (recording state for policy checks).
    recording_manager_handle: RecordingMessageSender,
}

impl SessionManagerTask {
pub fn new(rx: SessionMessageReceiver) -> Self {
/// Creates the task together with its internal message channel.
///
/// Use [`Self::handle`] afterwards to obtain `SessionMessageSender`s for this task.
pub fn init(recording_manager_handle: RecordingMessageSender) -> Self {
    let (tx, rx) = session_manager_channel();

    Self::new(tx, rx, recording_manager_handle)
}

/// Builds the task from an existing channel pair.
///
/// `tx` is expected to be the sender side matching `rx`, so that handles
/// returned by [`Self::handle`] reach this task. No session is tracked initially.
pub fn new(
    tx: SessionMessageSender,
    rx: SessionMessageReceiver,
    recording_manager_handle: RecordingMessageSender,
) -> Self {
    Self {
        tx,
        rx,
        all_running: HashMap::new(),
        all_notify_kill: HashMap::new(),
        recording_manager_handle,
    }
}

/// Returns a new `SessionMessageSender` handle targeting this task.
pub fn handle(&self) -> SessionMessageSender {
    self.tx.clone()
}

/// Registers a freshly created session along with its kill notifier.
fn handle_new(&mut self, info: SessionInfo, notify_kill: Arc<Notify>) {
    // Grab the ID before `info` is moved into the map.
    let session_id = info.association_id;
    self.all_notify_kill.insert(session_id, notify_kill);
    self.all_running.insert(session_id, info);
}

/// Returns a copy of the tracked [`SessionInfo`] for `id`, or `None` when no
/// such session is currently running.
///
/// Takes `&self` rather than `&mut self`: this is a pure lookup and does not
/// modify any task state.
fn handle_get_info(&self, id: Uuid) -> Option<SessionInfo> {
    self.all_running.get(&id).cloned()
}

fn handle_remove(&mut self, id: Uuid) -> Option<SessionInfo> {
let removed_session = self.all_running.remove(&id);
let _ = self.all_notify_kill.remove(&id);
Expand Down Expand Up @@ -312,17 +352,14 @@ async fn session_manager_task(
debug!("Task started");

let mut with_ttl = BinaryHeap::<WithTtlInfo>::new();

let auto_kill_sleep = tokio::time::sleep_until(tokio::time::Instant::now());
tokio::pin!(auto_kill_sleep);

// Consume initial sleep
(&mut auto_kill_sleep).await;
(&mut auto_kill_sleep).await; // Consume initial sleep.

loop {
tokio::select! {
() = &mut auto_kill_sleep, if !with_ttl.is_empty() => {
// Will never panic since we check for non-emptiness before entering this block
// Will never panic since we check for non-emptiness before entering this block.
let to_kill = with_ttl.pop().unwrap();

match manager.handle_kill(to_kill.session_id) {
Expand All @@ -334,7 +371,7 @@ async fn session_manager_task(
}
}

// Re-arm the Sleep instance with the next deadline if required
// Re-arm the Sleep instance with the next deadline if required.
if let Some(next) = with_ttl.peek() {
auto_kill_sleep.as_mut().reset(next.deadline)
}
Expand All @@ -350,24 +387,41 @@ async fn session_manager_task(
match msg {
SessionManagerMessage::New { info, notify_kill } => {
if let SessionTtl::Limited { minutes } = info.time_to_live {
let duration = Duration::from_secs(minutes.get() * 60);
let now = tokio::time::Instant::now();
let duration = Duration::from_secs(minutes.get() * 60);
let deadline = now + duration;

with_ttl.push(WithTtlInfo {
deadline,
session_id: info.id(),
});

// Reset the Sleep instance if the new deadline is sooner or it is already elapsed
// Reset the Sleep instance if the new deadline is sooner or it is already elapsed.
if auto_kill_sleep.is_elapsed() || deadline < auto_kill_sleep.deadline() {
auto_kill_sleep.as_mut().reset(deadline);
}

debug!(session.id = %info.id(), minutes = minutes.get(), "Limited TTL session registed");
debug!(session.id = %info.id(), minutes = minutes.get(), "Limited TTL session registered");
}

if info.recording_policy {
let task = EnsureRecordingPolicyTask {
session_id: info.id(),
session_manager_handle: manager.tx.clone(),
recording_manager_handle: manager.recording_manager_handle.clone(),
};

devolutions_gateway_task::spawn_task(task, shutdown_signal.clone()).detach();

debug!(session.id = %info.id(), "Session with recording policy registered");
}

manager.handle_new(info, notify_kill);
},
}
SessionManagerMessage::GetInfo { id, channel } => {
let session_info = manager.handle_get_info(id);
let _ = channel.send(session_info);
}
SessionManagerMessage::Remove { id, channel } => {
let removed_session = manager.handle_remove(id);
let _ = channel.send(removed_session);
Expand Down Expand Up @@ -416,3 +470,55 @@ async fn session_manager_task(

Ok(())
}

/// Task enforcing the recording policy for a single session (DGW-86).
///
/// Spawned when a session whose `recording_policy` flag is set gets registered;
/// after a grace period it verifies a recording is active for the session and
/// kills the session otherwise.
struct EnsureRecordingPolicyTask {
    // ID of the session whose recording policy is being enforced.
    session_id: Uuid,
    // Used to kill the session when the policy is violated.
    session_manager_handle: SessionMessageSender,
    // Used to query whether a recording is ongoing for the session.
    recording_manager_handle: RecordingMessageSender,
}

#[async_trait]
impl Task for EnsureRecordingPolicyTask {
    type Output = ();

    const NAME: &'static str = "ensure recording policy";

    /// Waits a 10-second grace period, then kills the session if no recording
    /// has been started for it.
    async fn run(self, mut shutdown_signal: ShutdownSignal) -> Self::Output {
        use futures::future::Either;
        use std::pin::pin;

        // Grace period: give the client some time to open the recording stream.
        let sleep = tokio::time::sleep(Duration::from_secs(10));
        let shutdown_signal = shutdown_signal.wait();

        // Bail out early — without killing anything — if the service is shutting down.
        match futures::future::select(pin!(sleep), pin!(shutdown_signal)).await {
            Either::Left(_) => {}
            Either::Right(_) => return,
        }

        // No state at the recording manager means no recording is ongoing.
        // NOTE(review): a failed query (`ok()` collapses the error to `None`) is
        // treated the same as "not recording" — confirm this is intended.
        let is_not_recording = self
            .recording_manager_handle
            .get_state(self.session_id)
            .await
            .ok()
            .flatten()
            .is_none();

        if is_not_recording {
            match self.session_manager_handle.kill_session(self.session_id).await {
                Ok(KillResult::Success) => {
                    warn!(
                        session.id = %self.session_id,
                        reason = "recording policy violated",
                        "Session killed",
                    );
                }
                Ok(KillResult::NotFound) => {
                    // The session ended on its own before the grace period elapsed.
                    trace!(session.id = %self.session_id, "Session already ended");
                }
                Err(error) => {
                    // NOTE(review): logged at debug (not error) level — presumably
                    // because this happens during shutdown; confirm.
                    debug!(session.id = %self.session_id, error = format!("{error:#}"), "Couldn’t kill the session");
                }
            }
        }
    }
}
Loading

0 comments on commit 13ed397

Please sign in to comment.