Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Configuration device plugin #627

61 changes: 37 additions & 24 deletions agent/src/util/config_action.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use super::{
DISCOVERY_OPERATOR_STOP_DISCOVERY_CHANNEL_CAPACITY,
},
device_plugin_service,
device_plugin_service::InstanceMap,
device_plugin_service::DevicePluginContext,
discovery_operator::start_discovery::{start_discovery, DiscoveryOperator},
registration::RegisteredDiscoveryHandlerMap,
};
Expand Down Expand Up @@ -32,7 +32,7 @@ type ConfigMap = Arc<RwLock<HashMap<ConfigId, ConfigInfo>>>;
#[derive(Debug)]
pub struct ConfigInfo {
/// Map of all of a Configuration's Instances
instance_map: InstanceMap,
device_plugin_context: Arc<RwLock<DevicePluginContext>>,
/// Sends notification to a `DiscoveryOperator` that it should stop all discovery for its Configuration.
/// This signals it to tell each of its subtasks to stop discovery.
/// A broadcast channel is used so both the sending and receiving ends can be cloned.
Expand Down Expand Up @@ -252,13 +252,13 @@ async fn handle_config_add(
config.metadata.name.clone().unwrap(),
);
// Create a new instance map for this config and add it to the config map
let instance_map: InstanceMap = Arc::new(RwLock::new(HashMap::new()));
let device_plugin_context = Arc::new(RwLock::new(DevicePluginContext::default()));
let (stop_discovery_sender, _): (broadcast::Sender<()>, broadcast::Receiver<()>) =
broadcast::channel(DISCOVERY_OPERATOR_STOP_DISCOVERY_CHANNEL_CAPACITY);
let (mut finished_discovery_sender, finished_discovery_receiver) =
mpsc::channel(DISCOVERY_OPERATOR_FINISHED_DISCOVERY_CHANNEL_CAPACITY);
let config_info = ConfigInfo {
instance_map: instance_map.clone(),
device_plugin_context: device_plugin_context.clone(),
stop_discovery_sender: stop_discovery_sender.clone(),
finished_discovery_receiver,
last_generation: config.metadata.generation,
Expand All @@ -269,7 +269,7 @@ async fn handle_config_add(
// Keep discovering instances until the config is deleted, signaled by a message from handle_config_delete
tokio::spawn(async move {
let discovery_operator =
DiscoveryOperator::new(discovery_handler_map, config, instance_map);
DiscoveryOperator::new(discovery_handler_map, config, device_plugin_context);
start_discovery(
discovery_operator,
new_discovery_handler_sender,
Expand Down Expand Up @@ -335,17 +335,29 @@ async fn handle_config_delete(
}

// Get map of instances for the Configuration and then remove Configuration from ConfigMap
let instance_map: InstanceMap;
let device_plugin_context;
{
let mut config_map_locked = config_map.write().await;
instance_map = config_map_locked
device_plugin_context = config_map_locked
.get(&config_id)
.unwrap()
.instance_map
.device_plugin_context
.clone();
config_map_locked.remove(&config_id);
}
delete_all_instances_in_map(kube_interface, instance_map, config_id).await?;
delete_all_instances_in_device_plugin_context(
kube_interface,
device_plugin_context.clone(),
config_id,
)
.await?;
if let Some(sender) = &device_plugin_context
.read()
.await
.usage_update_message_sender
{
sender.send(device_plugin_service::ListAndWatchMessageKind::End)?;
}
Ok(())
}

Expand Down Expand Up @@ -375,13 +387,13 @@ async fn should_recreate_config(
}

/// This shuts down all a Configuration's Instances and terminates the associated Device Plugins
pub async fn delete_all_instances_in_map(
pub async fn delete_all_instances_in_device_plugin_context(
kube_interface: &dyn k8s::KubeInterface,
instance_map: InstanceMap,
device_plugin_context: Arc<RwLock<DevicePluginContext>>,
(namespace, name): ConfigId,
) -> anyhow::Result<()> {
let mut instance_map_locked = instance_map.write().await;
let instances_to_delete_map = instance_map_locked.clone();
let mut device_plugin_context_locked = device_plugin_context.write().await;
let instances_to_delete_map = device_plugin_context_locked.clone().instances;
for (instance_name, instance_info) in instances_to_delete_map {
trace!(
"handle_config_delete - found Instance {} associated with deleted config {:?} ... sending message to end list_and_watch",
Expand All @@ -392,7 +404,9 @@ pub async fn delete_all_instances_in_map(
.list_and_watch_message_sender
.send(device_plugin_service::ListAndWatchMessageKind::End)
.unwrap();
instance_map_locked.remove(&instance_name);
device_plugin_context_locked
.instances
.remove(&instance_name);
try_delete_instance(kube_interface, &instance_name, namespace.as_str()).await?;
}
Ok(())
Expand All @@ -401,9 +415,8 @@ pub async fn delete_all_instances_in_map(
#[cfg(test)]
mod config_action_tests {
use super::super::{
device_plugin_service,
device_plugin_service::{InstanceConnectivityStatus, InstanceMap},
discovery_operator::tests::build_instance_map,
device_plugin_service, device_plugin_service::InstanceConnectivityStatus,
discovery_operator::tests::build_device_plugin_context,
};
use super::*;
use akri_shared::{akri::configuration::Configuration, k8s::MockKubeInterface};
Expand Down Expand Up @@ -468,7 +481,7 @@ mod config_action_tests {
let mut list_and_watch_message_receivers = Vec::new();
let mut visible_discovery_results = Vec::new();
let mut mock = MockKubeInterface::new();
let instance_map: InstanceMap = build_instance_map(
let device_plugin_context = build_device_plugin_context(
&config,
&mut visible_discovery_results,
&mut list_and_watch_message_receivers,
Expand All @@ -482,7 +495,7 @@ mod config_action_tests {
config_id.clone(),
ConfigInfo {
stop_discovery_sender,
instance_map: instance_map.clone(),
device_plugin_context: device_plugin_context.clone(),
finished_discovery_receiver,
last_generation: config.metadata.generation,
},
Expand Down Expand Up @@ -518,7 +531,7 @@ mod config_action_tests {
futures::future::join_all(tasks).await;

// Assert that all instances have been removed from the instance map
assert_eq!(instance_map.read().await.len(), 0);
assert_eq!(device_plugin_context.read().await.instances.len(), 0);
}

#[tokio::test]
Expand All @@ -534,7 +547,7 @@ mod config_action_tests {
let mut list_and_watch_message_receivers = Vec::new();
let mut visible_discovery_results = Vec::new();
let mut mock = MockKubeInterface::new();
let instance_map: InstanceMap = build_instance_map(
let device_plugin_context = build_device_plugin_context(
&config,
&mut visible_discovery_results,
&mut list_and_watch_message_receivers,
Expand All @@ -548,7 +561,7 @@ mod config_action_tests {
config_id.clone(),
ConfigInfo {
stop_discovery_sender,
instance_map: instance_map.clone(),
device_plugin_context: device_plugin_context.clone(),
finished_discovery_receiver,
last_generation: config.metadata.generation,
},
Expand Down Expand Up @@ -583,7 +596,7 @@ mod config_action_tests {
futures::future::join_all(tasks).await;

// Assert that all instances have been removed from the instance map
assert_eq!(instance_map.read().await.len(), 0);
assert_eq!(device_plugin_context.read().await.instances.len(), 0);
}

// Tests that when a Configuration is updated,
Expand Down Expand Up @@ -640,7 +653,7 @@ mod config_action_tests {
let (_, finished_discovery_receiver) = mpsc::channel(2);

let config_info = ConfigInfo {
instance_map: Arc::new(RwLock::new(HashMap::new())),
device_plugin_context: Arc::new(RwLock::new(DevicePluginContext::default())),
stop_discovery_sender: stop_discovery_sender.clone(),
finished_discovery_receiver,
last_generation: Some(1),
Expand Down
56 changes: 48 additions & 8 deletions agent/src/util/device_plugin_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ use super::{
KUBELET_SOCKET, LIST_AND_WATCH_MESSAGE_CHANNEL_CAPACITY,
},
device_plugin_service::{
DevicePluginBehavior, DevicePluginService, InstanceDevicePlugin, InstanceMap,
ListAndWatchMessageKind,
ConfigurationDevicePlugin, DevicePluginBehavior, DevicePluginContext, DevicePluginService,
InstanceDevicePlugin, ListAndWatchMessageKind,
},
v1beta1,
v1beta1::{
Expand All @@ -23,11 +23,12 @@ use futures::TryFutureExt;
use log::{info, trace};
#[cfg(test)]
use mockall::{automock, predicate::*};
use std::sync::Arc;
use std::{convert::TryFrom, env, path::Path, time::SystemTime};
use tokio::{
net::UnixListener,
net::UnixStream,
sync::{broadcast, mpsc},
sync::{broadcast, mpsc, RwLock},
task,
};
use tonic::transport::{Endpoint, Server, Uri};
Expand All @@ -42,9 +43,19 @@ pub trait DevicePluginBuilderInterface: Send + Sync {
instance_id: String,
config: &Configuration,
shared: bool,
instance_map: InstanceMap,
device_plugin_context: Arc<RwLock<DevicePluginContext>>,
device: Device,
) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>>;

async fn build_configuration_device_plugin(
&self,
device_plugin_name: String,
config: &Configuration,
device_plugin_context: Arc<RwLock<DevicePluginContext>>,
) -> Result<
broadcast::Sender<ListAndWatchMessageKind>,
Box<dyn std::error::Error + Send + Sync + 'static>,
>;
}

/// For each Instance, builds a Device Plugin, registers it with the kubelet, and serves it over UDS.
Expand All @@ -59,7 +70,7 @@ impl DevicePluginBuilderInterface for DevicePluginBuilder {
instance_id: String,
config: &Configuration,
shared: bool,
instance_map: InstanceMap,
device_plugin_context: Arc<RwLock<DevicePluginContext>>,
device: Device,
) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
info!("build_device_plugin - entered for device {}", instance_name);
Expand All @@ -73,20 +84,49 @@ impl DevicePluginBuilderInterface for DevicePluginBuilder {
self.build_device_plugin_service(
&instance_name,
config,
instance_map,
device_plugin_context,
device_plugin_behavior,
list_and_watch_message_sender,
)
.await
}

/// This creates a new ConfigurationDevicePluginService for a Configuration and registers it with the kubelet
async fn build_configuration_device_plugin(
&self,
device_plugin_name: String,
config: &Configuration,
device_plugin_context: Arc<RwLock<DevicePluginContext>>,
) -> Result<
broadcast::Sender<ListAndWatchMessageKind>,
Box<dyn std::error::Error + Send + Sync + 'static>,
> {
info!(
"build_configuration_device_plugin - entered for device {}",
device_plugin_name
);
let device_plugin_behavior =
DevicePluginBehavior::Configuration(ConfigurationDevicePlugin::default());
let (list_and_watch_message_sender, _) =
broadcast::channel(LIST_AND_WATCH_MESSAGE_CHANNEL_CAPACITY);
self.build_device_plugin_service(
&device_plugin_name,
config,
device_plugin_context,
device_plugin_behavior,
list_and_watch_message_sender.clone(),
)
.await?;
Ok(list_and_watch_message_sender)
}
}

impl DevicePluginBuilder {
async fn build_device_plugin_service(
&self,
device_plugin_name: &str,
config: &Configuration,
instance_map: InstanceMap,
device_plugin_context: Arc<RwLock<DevicePluginContext>>,
device_plugin_behavior: DevicePluginBehavior,
list_and_watch_message_sender: broadcast::Sender<ListAndWatchMessageKind>,
) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
Expand All @@ -108,7 +148,7 @@ impl DevicePluginBuilder {
config_uid: config.metadata.uid.as_ref().unwrap().clone(),
config_namespace: config.metadata.namespace.as_ref().unwrap().clone(),
node_name: env::var("AGENT_NODE_NAME")?,
instance_map,
device_plugin_context,
list_and_watch_message_sender,
server_ender_sender: server_ender_sender.clone(),
device_plugin_behavior,
Expand Down
Loading