diff --git a/host/src/adapter.rs b/host/src/adapter.rs index f58d2623..4a1522bf 100644 --- a/host/src/adapter.rs +++ b/host/src/adapter.rs @@ -179,7 +179,7 @@ where L2CAP_CID_LE_U_SIGNAL => { let mut r = ReadCursor::new(packet.payload); let signal: L2capLeSignal = r.read()?; - match self.channels.control(conn, signal) { + match self.channels.control(conn, signal).await { Ok(_) => {} Err(_) => { return Err(HandleError::Other); diff --git a/host/src/channel_manager.rs b/host/src/channel_manager.rs index efc7417f..7c66024d 100644 --- a/host/src/channel_manager.rs +++ b/host/src/channel_manager.rs @@ -7,11 +7,12 @@ use core::{ use bt_hci::param::ConnHandle; use embassy_sync::{ blocking_mutex::{raw::RawMutex, Mutex}, - channel::{Channel, DynamicReceiver}, + channel::{Channel, DynamicReceiver, DynamicSendFuture}, waitqueue::WakerRegistration, }; use crate::{ + codec, l2cap::L2capPacket, packet_pool::{AllocId, DynamicPacketPool}, pdu::Pdu, @@ -22,6 +23,24 @@ use crate::{ const BASE_ID: u16 = 0x40; +#[derive(Clone, Copy, Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum Error { + InvalidChannelId, + NoChannelAvailable, + ChannelNotFound, + OutOfMemory, + NotSupported, + ChannelClosed, + Codec(codec::Error), +} + +impl From for Error { + fn from(e: codec::Error) -> Self { + Self::Codec(e) + } +} + struct State { channels: [ChannelState; CHANNELS], accept_waker: WakerRegistration, @@ -38,9 +57,14 @@ pub struct ChannelManager<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TX //outbound: [Channel, L2CAP_TXQ>; CHANNELS], } -pub trait DynamicChannelManager { - fn poll_request_to_send(&self, cid: u16, credits: usize, cx: &mut Context<'_>) -> Poll>; - fn confirm_received(&self, cid: u16, credits: usize) -> Result<(), ()>; +pub trait DynamicChannelManager<'d> { + fn poll_request_to_send(&self, cid: u16, credits: usize, cx: &mut Context<'_>) -> Poll>; + fn confirm_received( + &'d self, + cid: u16, + credits: usize, + ) -> Result, Error>; + fn confirm_disconnected(&self, cid: u16) -> Result<(), Error>; } impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP_RXQ: usize> @@ -69,27 +93,29 @@ impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP self.signal.receive().await } - fn disconnect(&self, cid: u16) -> Result<(), ()> { - self.state.lock(|state| { + async fn disconnect(&self, cid: u16) -> Result<(), Error> { + let idx = self.state.lock(|state| { let mut state = state.borrow_mut(); - for storage in state.channels.iter_mut() { + for (idx, storage) in state.channels.iter_mut().enumerate() { match storage { ChannelState::Connecting(state) if cid == state.cid => { - *storage = ChannelState::Disconnecting(DisconnectingState { conn: state.conn }); - break; + *storage = ChannelState::Disconnecting(DisconnectingState { conn: state.conn, cid }); + return Ok(idx); } ChannelState::Connected(state) if cid == state.cid => { - *storage = ChannelState::Disconnecting(DisconnectingState { conn: state.conn }); - break; + *storage = ChannelState::Disconnecting(DisconnectingState { conn: state.conn, cid }); + return Ok(idx); } _ => {} } } - Ok(()) - }) + Err(Error::ChannelNotFound) + })?; + self.inbound[idx].send(None).await; + Ok(()) } - fn disconnected(&self, cid: u16) -> Result<(), ()> { + fn disconnected(&self, cid: u16) -> Result<(), Error> { self.state.lock(|state| { let mut state = state.borrow_mut(); for storage in state.channels.iter_mut() { @@ -109,7 +135,7 @@ impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP }) } - fn connect(&self, mut req: ConnectingState) -> Result<(usize, u16), ()> { + fn connect(&self, mut req: ConnectingState) -> Result<(usize, u16), Error> { self.state.lock(|state| { let mut state = state.borrow_mut(); for (idx, storage) in state.channels.iter_mut().enumerate() { @@ -121,11 +147,15 @@ impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP return Ok((idx, cid)); } } - Err(()) + Err(Error::NoChannelAvailable) }) } - fn connected ConnectedState>(&self, request_id: u8, f: F) -> Result<(), ()> { + fn connected ConnectedState>( + &self, + request_id: u8, + f: F, + ) -> Result<(), Error> { self.state.lock(|state| { let mut state = state.borrow_mut(); for (idx, storage) in state.channels.iter_mut().enumerate() { @@ -139,11 +169,11 @@ impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP _ => {} } } - Err(()) + Err(Error::ChannelNotFound) }) } - fn remote_credits(&self, cid: u16, credits: u16) -> Result<(), ()> { + fn remote_credits(&self, cid: u16, credits: u16) -> Result<(), Error> { self.state.lock(|state| { let mut state = state.borrow_mut(); for (idx, storage) in state.channels.iter_mut().enumerate() { @@ -156,7 +186,7 @@ impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP _ => {} } } - Err(()) + Err(Error::ChannelNotFound) }) } @@ -190,7 +220,7 @@ impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP conn: ConnHandle, psm: u16, mut mtu: u16, - ) -> Result<(ConnectedState, DynamicReceiver<'_, Option>>), ()> { + ) -> Result<(ConnectedState, DynamicReceiver<'_, Option>>), Error> { let mut req_id = 0; let (idx, state) = poll_fn(|cx| { self.poll_accept(conn, psm, cx, |idx, req| { @@ -233,7 +263,7 @@ impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP conn: ConnHandle, psm: u16, mtu: u16, - ) -> Result<(ConnectedState, DynamicReceiver<'_, Option>>), ()> { + ) -> Result<(ConnectedState, DynamicReceiver<'_, Option>>), Error> { let state = ConnectingState { conn, cid: 0, @@ -281,14 +311,14 @@ impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP Ok((state, rx)) } - pub async fn dispatch(&self, packet: L2capPacket<'_>) -> Result<(), ()> { + pub async fn dispatch(&self, packet: L2capPacket<'_>) -> Result<(), Error> { if packet.channel < BASE_ID { - return Err(()); + return Err(Error::InvalidChannelId); } let chan = (packet.channel - BASE_ID) as usize; if chan > self.inbound.len() { - return Err(()); + return Err(Error::InvalidChannelId); } let chan_alloc = AllocId::dynamic(chan); @@ -299,15 +329,15 @@ impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP Ok(()) } else { warn!("No memory for channel {}", packet.channel); - Err(()) + Err(Error::OutOfMemory) } } - pub fn control(&self, conn: ConnHandle, signal: L2capLeSignal) -> Result<(), ()> { + pub async fn control(&self, conn: ConnHandle, signal: L2capLeSignal) -> Result<(), Error> { // info!("Inbound signal: {:?}", signal); match signal.data { L2capLeSignalData::LeCreditConnReq(req) => { - if let Err(e) = self.connect(ConnectingState { + self.connect(ConnectingState { conn, cid: 0, psm: req.psm, @@ -316,17 +346,14 @@ impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP initial_credits: req.credits, mps: req.mps, mtu: req.mtu, - }) { - warn!("Error accepting connection: {:?}", e); - return Err(()); - } + })?; Ok(()) } L2capLeSignalData::LeCreditConnRes(res) => { match res.result { LeCreditConnResultCode::Success => { // Must be a response of a previous request which should already by allocated a channel for - match self.connected(signal.id, |idx, req| ConnectedState { + self.connected(signal.id, |idx, req| ConnectedState { conn: req.conn, cid: req.cid, psm: req.psm, @@ -336,14 +363,12 @@ impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP pool_id: AllocId::dynamic(idx), mps: req.mps.min(res.mps), mtu: req.mtu.min(res.mtu), - }) { - Ok(bound) => Ok(()), - Err(_) => Err(()), - } + })?; + Ok(()) } other => { warn!("Channel open request failed: {:?}", other); - Ok(()) + Err(Error::NotSupported) } } } @@ -356,7 +381,7 @@ impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP Ok(()) } L2capLeSignalData::DisconnectionReq(req) => { - warn!("Disconnection requested!"); + self.disconnect(req.dcid).await?; Ok(()) } L2capLeSignalData::DisconnectionRes(res) => { @@ -367,11 +392,15 @@ impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP } } -impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP_RXQ: usize> DynamicChannelManager +impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP_RXQ: usize> DynamicChannelManager<'d> for ChannelManager<'d, M, CHANNELS, L2CAP_TXQ, L2CAP_RXQ> { - fn confirm_received(&self, cid: u16, credits: usize) -> Result<(), ()> { - self.state.lock(|state| { + fn confirm_received( + &'d self, + cid: u16, + credits: usize, + ) -> Result, Error> { + let (conn, signal) = self.state.lock(|state| { let mut state = state.borrow_mut(); for (idx, storage) in state.channels.iter_mut().enumerate() { match storage { @@ -379,28 +408,43 @@ impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP // Don't set credits higher than what we can promise let increment = self.pool.min_available(AllocId::dynamic(idx)).min(credits); state.credits += increment as u16; - self.signal - .try_send(( - state.conn, - L2capLeSignal::new( - (cid % 255) as u8, - L2capLeSignalData::LeCreditFlowInd(LeCreditFlowInd { - cid: state.peer_cid, - credits: increment as u16, - }), - ), - )) - .map_err(|_| ())?; + return Ok(( + state.conn, + L2capLeSignal::new( + (cid % 255) as u8, + L2capLeSignalData::LeCreditFlowInd(LeCreditFlowInd { + cid: state.peer_cid, + credits: increment as u16, + }), + ), + )); + } + _ => {} + } + } + return Err(Error::ChannelNotFound); + })?; + let f = self.signal.send((conn, signal)); + Ok(f.into()) + } + + fn confirm_disconnected(&self, cid: u16) -> Result<(), Error> { + self.state.lock(|state| { + let mut state = state.borrow_mut(); + for storage in state.channels.iter_mut() { + match storage { + ChannelState::Disconnecting(state) if cid == state.cid => { + *storage = ChannelState::Disconnected; return Ok(()); } _ => {} } } - return Err(()); + return Err(Error::ChannelNotFound); }) } - fn poll_request_to_send(&self, cid: u16, credits: usize, cx: &mut Context<'_>) -> Poll> { + fn poll_request_to_send(&self, cid: u16, credits: usize, cx: &mut Context<'_>) -> Poll> { self.state.lock(|state| { let mut state = state.borrow_mut(); for (idx, storage) in state.channels.iter_mut().enumerate() { @@ -417,7 +461,7 @@ impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP _ => {} } } - return Poll::Ready(Err(())); + return Poll::Ready(Err(Error::ChannelNotFound)); }) } } @@ -459,4 +503,5 @@ pub struct ConnectedState { pub struct DisconnectingState { pub(crate) conn: ConnHandle, + pub(crate) cid: u16, } diff --git a/host/src/l2cap.rs b/host/src/l2cap.rs index 079c1c85..29bb1058 100644 --- a/host/src/l2cap.rs +++ b/host/src/l2cap.rs @@ -1,7 +1,7 @@ use core::future::poll_fn; use crate::adapter::Adapter; -use crate::channel_manager::DynamicChannelManager; +use crate::channel_manager::{self, DynamicChannelManager}; use crate::codec; use crate::connection::Connection; use crate::cursor::{ReadCursor, WriteCursor}; @@ -16,6 +16,8 @@ pub(crate) const L2CAP_CID_ATT: u16 = 0x0004; pub(crate) const L2CAP_CID_LE_U_SIGNAL: u16 = 0x0005; pub(crate) const L2CAP_CID_DYN_START: u16 = 0x0040; +pub use channel_manager::Error; + #[cfg_attr(feature = "defmt", derive(defmt::Format))] #[derive(Debug)] pub struct L2capPacket<'d> { @@ -51,13 +53,13 @@ pub struct L2capChannel<'d, const MTU: usize> { peer_cid: u16, mps: usize, pool: &'d dyn DynamicPacketPool<'d>, - manager: &'d dyn DynamicChannelManager, + manager: &'d dyn DynamicChannelManager<'d>, rx: DynamicReceiver<'d, Option>>, tx: DynamicSender<'d, (ConnHandle, Pdu<'d>)>, } impl<'d, const MTU: usize> L2capChannel<'d, MTU> { - pub async fn send(&mut self, buf: &[u8]) -> Result<(), ()> { + pub async fn send(&mut self, buf: &[u8]) -> Result<(), Error> { // The number of packets we'll need to send for this payload let n_packets = 1 + (buf.len().saturating_sub(self.mps - 2)).div_ceil(self.mps); @@ -65,7 +67,7 @@ impl<'d, const MTU: usize> L2capChannel<'d, MTU> { // for pool to get the available packets back, which would require some poll/async behavior // support for the pool. if self.pool.available(self.pool_id) < n_packets { - return Err(()); + return Err(Error::OutOfMemory); } poll_fn(|cx| self.manager.poll_request_to_send(self.cid, n_packets, cx)).await?; @@ -75,11 +77,11 @@ impl<'d, const MTU: usize> L2capChannel<'d, MTU> { if let Some(mut packet) = self.pool.alloc(self.pool_id) { let len = { let mut w = WriteCursor::new(packet.as_mut()); - w.write(2 + first.len() as u16).map_err(|_| ())?; - w.write(self.peer_cid as u16).map_err(|_| ())?; + w.write(2 + first.len() as u16)?; + w.write(self.peer_cid as u16)?; let len = buf.len() as u16; - w.write(len).map_err(|_| ())?; - w.append(first).map_err(|_| ())?; + w.write(len)?; + w.append(first)?; w.len() }; let pdu = if remaining.is_empty() { @@ -89,7 +91,7 @@ impl<'d, const MTU: usize> L2capChannel<'d, MTU> { }; self.tx.send((self.conn, pdu)).await; } else { - return Err(()); + return Err(Error::OutOfMemory); } let chunks = remaining.chunks(self.mps); @@ -98,9 +100,9 @@ impl<'d, const MTU: usize> L2capChannel<'d, MTU> { if let Some(mut packet) = self.pool.alloc(self.pool_id) { let len = { let mut w = WriteCursor::new(packet.as_mut()); - w.write(chunk.len() as u16).map_err(|_| ())?; - w.write(self.peer_cid as u16).map_err(|_| ())?; - w.append(chunk).map_err(|_| ())?; + w.write(chunk.len() as u16)?; + w.write(self.peer_cid as u16)?; + w.append(chunk)?; w.len() }; let pdu = if i == num_chunks - 1 { @@ -110,18 +112,28 @@ impl<'d, const MTU: usize> L2capChannel<'d, MTU> { }; self.tx.send((self.conn, pdu)).await; } else { - return Err(()); + return Err(Error::OutOfMemory); } } Ok(()) } - pub async fn receive(&mut self, buf: &mut [u8]) -> Result { + async fn receive_pdu(&mut self) -> Result, Error> { + match self.rx.receive().await { + Some(pdu) => Ok(pdu), + None => { + self.manager.confirm_disconnected(self.cid)?; + Err(Error::ChannelClosed) + } + } + } + + pub async fn receive(&mut self, buf: &mut [u8]) -> Result { let mut n_received = 1; - let packet = self.rx.receive().await.ok_or(())?; + let packet = self.receive_pdu().await?; let mut r = ReadCursor::new(&packet.as_ref()); - let remaining: u16 = r.read().map_err(|_| ())?; + let remaining: u16 = r.read()?; let data = r.remaining(); let to_copy = data.len().min(buf.len()); @@ -131,7 +143,7 @@ impl<'d, const MTU: usize> L2capChannel<'d, MTU> { let mut remaining = remaining as usize - data.len(); // We have some k-frames to reassemble while remaining > 0 { - let packet = self.rx.receive().await.ok_or(())?; + let packet = self.receive_pdu().await?; n_received += 1; let to_copy = packet.len.min(buf.len() - pos); if to_copy > 0 { @@ -141,7 +153,7 @@ impl<'d, const MTU: usize> L2capChannel<'d, MTU> { remaining -= packet.len; } - self.manager.confirm_received(self.cid, n_received)?; + self.manager.confirm_received(self.cid, n_received)?.await; Ok(pos) } @@ -156,7 +168,7 @@ impl<'d, const MTU: usize> L2capChannel<'d, MTU> { adapter: &'d Adapter<'d, M, CONNS, CHANNELS, L2CAP_TXQ, L2CAP_RXQ>, connection: &Connection<'d>, psm: u16, - ) -> Result { + ) -> Result { let connections = &adapter.connections; let channels = &adapter.channels; @@ -186,7 +198,7 @@ impl<'d, const MTU: usize> L2capChannel<'d, MTU> { adapter: &'d Adapter<'d, M, CONNS, CHANNELS, L2CAP_TXQ, L2CAP_RXQ>, connection: &Connection<'d>, psm: u16, - ) -> Result { + ) -> Result { // TODO: Use unique signal ID to ensure no collision of signal messages // let (state, rx) = adapter.channels.create(connection.handle(), psm, MTU as u16).await?;