Skip to content

Commit

Permalink
disconnect and error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
lulf committed Mar 12, 2024
1 parent 5c4f303 commit b9aaf94
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 78 deletions.
2 changes: 1 addition & 1 deletion host/src/adapter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ where
L2CAP_CID_LE_U_SIGNAL => {
let mut r = ReadCursor::new(packet.payload);
let signal: L2capLeSignal = r.read()?;
match self.channels.control(conn, signal) {
match self.channels.control(conn, signal).await {
Ok(_) => {}
Err(_) => {
return Err(HandleError::Other);
Expand Down
159 changes: 102 additions & 57 deletions host/src/channel_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@ use core::{
use bt_hci::param::ConnHandle;
use embassy_sync::{
blocking_mutex::{raw::RawMutex, Mutex},
channel::{Channel, DynamicReceiver},
channel::{Channel, DynamicReceiver, DynamicSendFuture},
waitqueue::WakerRegistration,
};

use crate::{
codec,
l2cap::L2capPacket,
packet_pool::{AllocId, DynamicPacketPool},
pdu::Pdu,
Expand All @@ -22,6 +23,24 @@ use crate::{

const BASE_ID: u16 = 0x40;

#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum Error {
InvalidChannelId,
NoChannelAvailable,
ChannelNotFound,
OutOfMemory,
NotSupported,
ChannelClosed,
Codec(codec::Error),
}

impl From<codec::Error> for Error {
fn from(e: codec::Error) -> Self {
Self::Codec(e)
}
}

struct State<const CHANNELS: usize> {
channels: [ChannelState; CHANNELS],
accept_waker: WakerRegistration,
Expand All @@ -38,9 +57,14 @@ pub struct ChannelManager<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TX
//outbound: [Channel<M, Pdu<'d>, L2CAP_TXQ>; CHANNELS],
}

pub trait DynamicChannelManager {
fn poll_request_to_send(&self, cid: u16, credits: usize, cx: &mut Context<'_>) -> Poll<Result<(), ()>>;
fn confirm_received(&self, cid: u16, credits: usize) -> Result<(), ()>;
pub trait DynamicChannelManager<'d> {
fn poll_request_to_send(&self, cid: u16, credits: usize, cx: &mut Context<'_>) -> Poll<Result<(), Error>>;
fn confirm_received(
&'d self,
cid: u16,
credits: usize,
) -> Result<DynamicSendFuture<'d, (ConnHandle, L2capLeSignal)>, Error>;
fn confirm_disconnected(&self, cid: u16) -> Result<(), Error>;
}

impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP_RXQ: usize>
Expand Down Expand Up @@ -69,27 +93,29 @@ impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP
self.signal.receive().await
}

fn disconnect(&self, cid: u16) -> Result<(), ()> {
self.state.lock(|state| {
async fn disconnect(&self, cid: u16) -> Result<(), Error> {
let idx = self.state.lock(|state| {
let mut state = state.borrow_mut();
for storage in state.channels.iter_mut() {
for (idx, storage) in state.channels.iter_mut().enumerate() {
match storage {
ChannelState::Connecting(state) if cid == state.cid => {
*storage = ChannelState::Disconnecting(DisconnectingState { conn: state.conn });
break;
*storage = ChannelState::Disconnecting(DisconnectingState { conn: state.conn, cid });
return Ok(idx);
}
ChannelState::Connected(state) if cid == state.cid => {
*storage = ChannelState::Disconnecting(DisconnectingState { conn: state.conn });
break;
*storage = ChannelState::Disconnecting(DisconnectingState { conn: state.conn, cid });
return Ok(idx);
}
_ => {}
}
}
Ok(())
})
Err(Error::ChannelNotFound)
})?;
self.inbound[idx].send(None).await;
Ok(())
}

fn disconnected(&self, cid: u16) -> Result<(), ()> {
fn disconnected(&self, cid: u16) -> Result<(), Error> {
self.state.lock(|state| {
let mut state = state.borrow_mut();
for storage in state.channels.iter_mut() {
Expand All @@ -109,7 +135,7 @@ impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP
})
}

fn connect(&self, mut req: ConnectingState) -> Result<(usize, u16), ()> {
fn connect(&self, mut req: ConnectingState) -> Result<(usize, u16), Error> {
self.state.lock(|state| {
let mut state = state.borrow_mut();
for (idx, storage) in state.channels.iter_mut().enumerate() {
Expand All @@ -121,11 +147,15 @@ impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP
return Ok((idx, cid));
}
}
Err(())
Err(Error::NoChannelAvailable)
})
}

fn connected<F: FnOnce(usize, &ConnectingState) -> ConnectedState>(&self, request_id: u8, f: F) -> Result<(), ()> {
fn connected<F: FnOnce(usize, &ConnectingState) -> ConnectedState>(
&self,
request_id: u8,
f: F,
) -> Result<(), Error> {
self.state.lock(|state| {
let mut state = state.borrow_mut();
for (idx, storage) in state.channels.iter_mut().enumerate() {
Expand All @@ -139,11 +169,11 @@ impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP
_ => {}
}
}
Err(())
Err(Error::ChannelNotFound)
})
}

fn remote_credits(&self, cid: u16, credits: u16) -> Result<(), ()> {
fn remote_credits(&self, cid: u16, credits: u16) -> Result<(), Error> {
self.state.lock(|state| {
let mut state = state.borrow_mut();
for (idx, storage) in state.channels.iter_mut().enumerate() {
Expand All @@ -156,7 +186,7 @@ impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP
_ => {}
}
}
Err(())
Err(Error::ChannelNotFound)
})
}

Expand Down Expand Up @@ -190,7 +220,7 @@ impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP
conn: ConnHandle,
psm: u16,
mut mtu: u16,
) -> Result<(ConnectedState, DynamicReceiver<'_, Option<Pdu<'d>>>), ()> {
) -> Result<(ConnectedState, DynamicReceiver<'_, Option<Pdu<'d>>>), Error> {
let mut req_id = 0;
let (idx, state) = poll_fn(|cx| {
self.poll_accept(conn, psm, cx, |idx, req| {
Expand Down Expand Up @@ -233,7 +263,7 @@ impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP
conn: ConnHandle,
psm: u16,
mtu: u16,
) -> Result<(ConnectedState, DynamicReceiver<'_, Option<Pdu<'d>>>), ()> {
) -> Result<(ConnectedState, DynamicReceiver<'_, Option<Pdu<'d>>>), Error> {
let state = ConnectingState {
conn,
cid: 0,
Expand Down Expand Up @@ -281,14 +311,14 @@ impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP
Ok((state, rx))
}

pub async fn dispatch(&self, packet: L2capPacket<'_>) -> Result<(), ()> {
pub async fn dispatch(&self, packet: L2capPacket<'_>) -> Result<(), Error> {
if packet.channel < BASE_ID {
return Err(());
return Err(Error::InvalidChannelId);
}

let chan = (packet.channel - BASE_ID) as usize;
if chan > self.inbound.len() {
return Err(());
return Err(Error::InvalidChannelId);
}

let chan_alloc = AllocId::dynamic(chan);
Expand All @@ -299,15 +329,15 @@ impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP
Ok(())
} else {
warn!("No memory for channel {}", packet.channel);
Err(())
Err(Error::OutOfMemory)
}
}

pub fn control(&self, conn: ConnHandle, signal: L2capLeSignal) -> Result<(), ()> {
pub async fn control(&self, conn: ConnHandle, signal: L2capLeSignal) -> Result<(), Error> {
// info!("Inbound signal: {:?}", signal);
match signal.data {
L2capLeSignalData::LeCreditConnReq(req) => {
if let Err(e) = self.connect(ConnectingState {
self.connect(ConnectingState {
conn,
cid: 0,
psm: req.psm,
Expand All @@ -316,17 +346,14 @@ impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP
initial_credits: req.credits,
mps: req.mps,
mtu: req.mtu,
}) {
warn!("Error accepting connection: {:?}", e);
return Err(());
}
})?;
Ok(())
}
L2capLeSignalData::LeCreditConnRes(res) => {
match res.result {
LeCreditConnResultCode::Success => {
// Must be a response of a previous request which should already by allocated a channel for
match self.connected(signal.id, |idx, req| ConnectedState {
self.connected(signal.id, |idx, req| ConnectedState {
conn: req.conn,
cid: req.cid,
psm: req.psm,
Expand All @@ -336,14 +363,12 @@ impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP
pool_id: AllocId::dynamic(idx),
mps: req.mps.min(res.mps),
mtu: req.mtu.min(res.mtu),
}) {
Ok(bound) => Ok(()),
Err(_) => Err(()),
}
})?;
Ok(())
}
other => {
warn!("Channel open request failed: {:?}", other);
Ok(())
Err(Error::NotSupported)
}
}
}
Expand All @@ -356,7 +381,7 @@ impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP
Ok(())
}
L2capLeSignalData::DisconnectionReq(req) => {
warn!("Disconnection requested!");
self.disconnect(req.dcid).await?;
Ok(())
}
L2capLeSignalData::DisconnectionRes(res) => {
Expand All @@ -367,40 +392,59 @@ impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP
}
}

impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP_RXQ: usize> DynamicChannelManager
impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP_RXQ: usize> DynamicChannelManager<'d>
for ChannelManager<'d, M, CHANNELS, L2CAP_TXQ, L2CAP_RXQ>
{
fn confirm_received(&self, cid: u16, credits: usize) -> Result<(), ()> {
self.state.lock(|state| {
fn confirm_received(
&'d self,
cid: u16,
credits: usize,
) -> Result<DynamicSendFuture<'d, (ConnHandle, L2capLeSignal)>, Error> {
let (conn, signal) = self.state.lock(|state| {
let mut state = state.borrow_mut();
for (idx, storage) in state.channels.iter_mut().enumerate() {
match storage {
ChannelState::Connected(state) if cid == state.cid => {
// Don't set credits higher than what we can promise
let increment = self.pool.min_available(AllocId::dynamic(idx)).min(credits);
state.credits += increment as u16;
self.signal
.try_send((
state.conn,
L2capLeSignal::new(
(cid % 255) as u8,
L2capLeSignalData::LeCreditFlowInd(LeCreditFlowInd {
cid: state.peer_cid,
credits: increment as u16,
}),
),
))
.map_err(|_| ())?;
return Ok((
state.conn,
L2capLeSignal::new(
(cid % 255) as u8,
L2capLeSignalData::LeCreditFlowInd(LeCreditFlowInd {
cid: state.peer_cid,
credits: increment as u16,
}),
),
));
}
_ => {}
}
}
return Err(Error::ChannelNotFound);
})?;
let f = self.signal.send((conn, signal));
Ok(f.into())
}

fn confirm_disconnected(&self, cid: u16) -> Result<(), Error> {
self.state.lock(|state| {
let mut state = state.borrow_mut();
for storage in state.channels.iter_mut() {
match storage {
ChannelState::Disconnecting(state) if cid == state.cid => {
*storage = ChannelState::Disconnected;
return Ok(());
}
_ => {}
}
}
return Err(());
return Err(Error::ChannelNotFound);
})
}

fn poll_request_to_send(&self, cid: u16, credits: usize, cx: &mut Context<'_>) -> Poll<Result<(), ()>> {
fn poll_request_to_send(&self, cid: u16, credits: usize, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
self.state.lock(|state| {
let mut state = state.borrow_mut();
for (idx, storage) in state.channels.iter_mut().enumerate() {
Expand All @@ -417,7 +461,7 @@ impl<'d, M: RawMutex, const CHANNELS: usize, const L2CAP_TXQ: usize, const L2CAP
_ => {}
}
}
return Poll::Ready(Err(()));
return Poll::Ready(Err(Error::ChannelNotFound));
})
}
}
Expand Down Expand Up @@ -459,4 +503,5 @@ pub struct ConnectedState {

pub struct DisconnectingState {
pub(crate) conn: ConnHandle,
pub(crate) cid: u16,
}
Loading

0 comments on commit b9aaf94

Please sign in to comment.