diff --git a/host/src/adapter.rs b/host/src/adapter.rs index 97991a1f..4410df29 100644 --- a/host/src/adapter.rs +++ b/host/src/adapter.rs @@ -28,7 +28,7 @@ use futures_intrusive::sync::LocalSemaphore; use crate::advertise::{Advertisement, AdvertisementConfig, RawAdvertisement}; use crate::channel_manager::ChannelManager; use crate::connection::{ConnectConfig, Connection}; -use crate::connection_manager::{ConnectionInfo, ConnectionManager}; +use crate::connection_manager::ConnectionManager; use crate::cursor::WriteCursor; use crate::l2cap::sar::PacketReassembly; use crate::packet_pool::{AllocId, DynamicPacketPool, PacketPool, Qos}; @@ -537,7 +537,7 @@ where // Avoids using the packet buffer for signalling packets if header.channel == L2CAP_CID_LE_U_SIGNAL { assert!(data.len() == header.length as usize); - self.channels.control(acl.handle(), &data).await?; + self.channels.signal(acl.handle(), &data).await?; return Ok(()); } @@ -697,20 +697,7 @@ where Event::Le(event) => match event { LeEvent::LeConnectionComplete(e) => match e.status.to_result() { Ok(_) => { - if let Err(err) = self.connections.connect( - e.handle, - ConnectionInfo { - handle: e.handle, - status: e.status, - role: e.role, - peer_addr_kind: e.peer_addr_kind, - peer_address: e.peer_addr, - interval: e.conn_interval.as_u16(), - latency: e.peripheral_latency, - timeout: e.supervision_timeout.as_u16(), - att_mtu: 23, - }, - ) { + if let Err(err) = self.connections.connect(e.handle, &e) { warn!("Error establishing connection: {:?}", err); self.command(Disconnect::new( e.handle, @@ -747,7 +734,7 @@ where disconnects += 1; info!("Disconnected (total {}): {:?}", disconnects, e); let _ = self.connections.disconnect(e.handle); - let _ = self.channels.disconnected_connection(e.handle); + let _ = self.channels.disconnected(e.handle); } Event::NumberOfCompletedPackets(c) => { // info!("Confirmed {} packets sent", c.completed_packets.len()); diff --git a/host/src/channel_manager.rs b/host/src/channel_manager.rs index eb1c3394..be5065ab 100644 --- a/host/src/channel_manager.rs +++ b/host/src/channel_manager.rs @@ -24,7 +24,7 @@ const BASE_ID: u16 = 0x40; struct State { next_req_id: u8, - channels: [ChannelState; CHANNELS], + channels: [ChannelStorage; CHANNELS], accept_waker: WakerRegistration, create_waker: WakerRegistration, credit_wakers: [WakerRegistration; CHANNELS], @@ -55,14 +55,13 @@ impl< { const TX_CHANNEL: Channel, L2CAP_TXQ> = Channel::new(); const RX_CHANNEL: Channel>, L2CAP_RXQ> = Channel::new(); - const DISCONNECTED: ChannelState = ChannelState::Disconnected; const CREDIT_WAKER: WakerRegistration = WakerRegistration::new(); pub fn new(pool: &'d dyn DynamicPacketPool<'d>) -> Self { Self { pool, state: Mutex::new(RefCell::new(State { next_req_id: 0, - channels: [Self::DISCONNECTED; CHANNELS], + channels: [ChannelStorage::DISCONNECTED; CHANNELS], accept_waker: WakerRegistration::new(), create_waker: WakerRegistration::new(), credit_wakers: [Self::CREDIT_WAKER; CHANNELS], @@ -84,79 +83,55 @@ impl< }) } - pub(crate) fn disconnect(&self, cid: u16) -> Result<(), Error> { - let idx = self.state.lock(|state| { + pub(crate) fn disconnect(&self, cid: u16) -> Result { + let handle = self.state.lock(|state| { let mut state = state.borrow_mut(); for (idx, storage) in state.channels.iter_mut().enumerate() { - match storage { - ChannelState::Disconnecting(state) if cid == state.cid => { - *storage = ChannelState::Disconnected; - return Ok(idx); + match storage.state { + ChannelState::Disconnecting if cid == storage.cid => { + storage.state = ChannelState::Disconnected; + storage.cid = 0; + let _ = self.inbound[idx].try_send(None); + return Ok(storage.conn); } - ChannelState::PeerConnecting(state) if cid == state.cid => { - *storage = ChannelState::Disconnecting(DisconnectingState { conn: state.conn, cid }); - return Ok(idx); + ChannelState::PeerConnecting(_) if cid == storage.cid => { + storage.state = ChannelState::Disconnecting; + let _ = self.inbound[idx].try_send(None); + return Ok(storage.conn); } - ChannelState::Connecting(state) if cid == state.cid => { - *storage = ChannelState::Disconnecting(DisconnectingState { conn: state.conn, cid }); - return Ok(idx); + ChannelState::Connecting(_) if cid == storage.cid => { + storage.state = ChannelState::Disconnecting; + let _ = self.inbound[idx].try_send(None); + return Ok(storage.conn); } - ChannelState::Connected(state) if cid == state.cid => { - *storage = ChannelState::Disconnecting(DisconnectingState { conn: state.conn, cid }); - return Ok(idx); + ChannelState::Connected if cid == storage.cid => { + storage.state = ChannelState::Disconnecting; + let _ = self.inbound[idx].try_send(None); + return Ok(storage.conn); } _ => {} } } Err(Error::NotFound) })?; - let _ = self.inbound[idx].try_send(None); - Ok(()) + Ok(ConnHandle::new(handle)) } - fn disconnected(&self, cid: u16) -> Result<(), Error> { - self.state.lock(|state| { - let mut state = state.borrow_mut(); - for storage in state.channels.iter_mut() { - match storage { - ChannelState::Disconnecting(state) if cid == state.cid => { - *storage = ChannelState::Disconnected; - break; - } - ChannelState::PeerConnecting(state) if cid == state.cid => { - *storage = ChannelState::Disconnecting(DisconnectingState { conn: state.conn, cid }); - break; - } - ChannelState::Connecting(state) if cid == state.cid => { - *storage = ChannelState::Disconnecting(DisconnectingState { conn: state.conn, cid }); - break; - } - ChannelState::Connected(state) if cid == state.cid => { - *storage = ChannelState::Disconnecting(DisconnectingState { conn: state.conn, cid }); - break; - } - _ => {} - } - } - Ok(()) - }) - } - - pub fn disconnected_connection(&self, conn: ConnHandle) -> Result<(), Error> { + pub(crate) fn disconnected(&self, conn: ConnHandle) -> Result<(), Error> { self.state.lock(|state| { let mut state = state.borrow_mut(); for (idx, storage) in state.channels.iter_mut().enumerate() { - match storage { - ChannelState::PeerConnecting(state) if conn == state.conn => { - *storage = ChannelState::Disconnecting(DisconnectingState { conn, cid: state.cid }); + match storage.state { + ChannelState::PeerConnecting(_) if conn.raw() == storage.conn => { + storage.state = ChannelState::Disconnecting; let _ = self.inbound[idx].try_send(None); } - ChannelState::Connecting(state) if conn == state.conn => { - *storage = ChannelState::Disconnecting(DisconnectingState { conn, cid: state.cid }); + ChannelState::Connecting(_) if conn.raw() == storage.conn => { + storage.state = ChannelState::Disconnecting; let _ = self.inbound[idx].try_send(None); } - ChannelState::Connected(state) if conn == state.conn => { - *storage = ChannelState::Disconnecting(DisconnectingState { conn, cid: state.cid }); + ChannelState::Connected if conn.raw() == storage.conn => { + storage.state = ChannelState::Disconnecting; let _ = self.inbound[idx].try_send(None); } _ => {} @@ -171,16 +146,14 @@ impl< Ok(()) } - fn peer_connect PeerConnectingState>(&self, f: F) -> Result<(), Error> { + fn alloc(&self, f: F) -> Result<(), Error> { self.state.lock(|state| { let mut state = state.borrow_mut(); for (idx, storage) in state.channels.iter_mut().enumerate() { - if let ChannelState::Disconnected = storage { + if let ChannelState::Disconnected = storage.state { let cid: u16 = BASE_ID + idx as u16; - let mut req = f(idx, cid); - req.cid = cid; - *storage = ChannelState::PeerConnecting(req); - state.accept_waker.wake(); + storage.cid = cid; + f(storage); return Ok(()); } } @@ -188,127 +161,52 @@ impl< }) } - fn connect ConnectingState>(&self, f: F) -> Result<(), Error> { - self.state.lock(|state| { - let mut state = state.borrow_mut(); - for (idx, storage) in state.channels.iter_mut().enumerate() { - if let ChannelState::Disconnected = storage { - let cid: u16 = BASE_ID + idx as u16; - let mut req = f(idx, cid); - req.cid = cid; - *storage = ChannelState::Connecting(req); - return Ok(()); - } - } - Err(Error::NoChannelAvailable) - }) - } - - fn connected ConnectedState>( - &self, - request_id: u8, - f: F, - ) -> Result<(), Error> { - self.state.lock(|state| { - let mut state = state.borrow_mut(); - for (idx, storage) in state.channels.iter_mut().enumerate() { - match storage { - ChannelState::Connecting(req) if request_id == req.request_id => { - let res = f(idx, req); - // info!("Connection created, properties: {:?}", res); - *storage = ChannelState::Connected(res); - state.create_waker.wake(); - return Ok(()); - } - _ => {} - } - } - Err(Error::NotFound) - }) - } - - fn remote_credits(&self, cid: u16, credits: u16) -> Result<(), Error> { - self.state.lock(|state| { - let mut state = state.borrow_mut(); - for (idx, storage) in state.channels.iter_mut().enumerate() { - match storage { - ChannelState::Connected(s) if s.peer_cid == cid => { - s.peer_credits += credits; - state.credit_wakers[idx].wake(); - return Ok(()); - } - _ => {} - } - } - Err(Error::NotFound) - }) - } - - fn poll_accept ConnectedState>( - &self, - conn: ConnHandle, - psm: &[u16], - cx: &mut Context<'_>, - f: F, - ) -> Poll<(usize, ConnectedState)> { - self.state.lock(|state| { - let mut state = state.borrow_mut(); - for (idx, storage) in state.channels.iter_mut().enumerate() { - match storage { - ChannelState::PeerConnecting(req) if req.conn == conn && psm.contains(&req.psm) => { - let state = f(idx, req); - let cid = state.cid; - *storage = ChannelState::Connected(state.clone()); - return Poll::Ready((idx, state)); - } - _ => {} - } - } - state.accept_waker.register(cx.waker()); - Poll::Pending - }) - } - pub(crate) async fn accept( &self, conn: ConnHandle, psm: &[u16], - mut mtu: u16, + mtu: u16, credit_flow: CreditFlowPolicy, initial_credits: Option, controller: &HciController<'_, T>, ) -> Result> { - let mut req_id = 0; - let (idx, state) = poll_fn(|cx| { - self.poll_accept(conn, psm, cx, |idx, req| { - req_id = req.request_id; - let mps = req.mps.min(self.pool.mtu() as u16 - 4); - mtu = req.mtu.min(mtu); - let credits = initial_credits.unwrap_or(self.pool.min_available(AllocId::dynamic(idx)) as u16); - // info!("Accept L2CAP, initial credits: {}", credits); - ConnectedState { - conn: req.conn, - cid: req.cid, - psm: req.psm, - flow_control: CreditFlowControl::new(credit_flow, credits), - peer_credits: req.offered_credits, - peer_cid: req.peer_cid, - pool_id: AllocId::dynamic(idx), - mps, - mtu, + // Wait until we find a channel for our connection in the connecting state matching our PSM. + let (req_id, mps, mtu, cid, credits) = poll_fn(|cx| { + self.state.lock(|state| { + let mut state = state.borrow_mut(); + for chan in state.channels.iter_mut() { + match chan.state { + ChannelState::PeerConnecting(req_id) if chan.conn == conn.raw() && psm.contains(&chan.psm) => { + chan.mps = chan.mps.min(self.pool.mtu() as u16 - 4); + chan.mtu = chan.mtu.min(mtu); + chan.mtu = mtu; + chan.flow_control = CreditFlowControl::new( + credit_flow, + initial_credits + .unwrap_or(self.pool.min_available(AllocId::from_channel(chan.cid)) as u16), + ); + chan.state = ChannelState::Connected; + + return Poll::Ready((req_id, chan.mps, chan.mtu, chan.cid, chan.flow_control.available())); + } + _ => {} + } } + state.accept_waker.register(cx.waker()); + Poll::Pending }) }) .await; let mut tx = [0; 18]; + // Respond that we accept the channel. controller .signal( conn, req_id, &LeCreditConnRes { - mps: state.mps, - dcid: state.cid, + mps, + dcid: cid, mtu, credits: 0, result: LeCreditConnResultCode::Success, @@ -319,20 +217,11 @@ impl< // Send initial credits let next_req_id = self.next_request_id(); - controller - .signal( - conn, - next_req_id, - &LeCreditFlowInd { - cid: state.cid, - credits: state.flow_control.available(), - }, - &mut tx[..], - ) + .signal(conn, next_req_id, &LeCreditFlowInd { cid, credits }, &mut tx[..]) .await?; - Ok(state.cid) + Ok(cid) } pub(crate) async fn create( @@ -347,46 +236,41 @@ impl< let req_id = self.next_request_id(); let mut credits = 0; let mut cid: u16 = 0; - self.connect(|i, c| { - cid = c; - credits = initial_credits.unwrap_or(self.pool.min_available(AllocId::dynamic(i)) as u16); - ConnectingState { - conn, - cid, - request_id: req_id, - psm, - initial_credits: credits, - flow_control_policy: credit_flow, - mps: self.pool.mtu() as u16 - 4, - mtu, - } + let mps = self.pool.mtu() as u16 - 4; + + // Allocate space for our new channel. + self.alloc(|storage| { + cid = storage.cid; + credits = initial_credits.unwrap_or(self.pool.min_available(AllocId::from_channel(storage.cid)) as u16); + storage.mps = mps; + storage.mtu = mtu; + storage.flow_control = CreditFlowControl::new(credit_flow, credits); + + storage.state = ChannelState::Connecting(req_id); })?; - //info!("Created connect state with idx cid {}", cid); - // - let mut tx = [0; 18]; + let mut tx = [0; 18]; + // Send the initial connect request. let command = LeCreditConnReq { psm, - mps: self.pool.mtu() as u16 - 4, + mps, scid: cid, mtu, credits: 0, }; - - //info!("Signal packet to remote: {:?}", command); controller.signal(conn, req_id, &command, &mut tx[..]).await?; - // info!("Sent signal packet to remote, awaiting response"); - let (idx, cid) = poll_fn(|cx| { + // Wait until a response is accepted. + poll_fn(|cx| { self.state.lock(|state| { let mut state = state.borrow_mut(); - for (idx, storage) in state.channels.iter_mut().enumerate() { - match storage { - ChannelState::Disconnecting(req) if req.conn == conn && req.cid == cid => { + for storage in state.channels.iter_mut() { + match storage.state { + ChannelState::Disconnecting if storage.conn == conn.raw() && storage.cid == cid => { return Poll::Ready(Err(Error::Disconnected)); } - ChannelState::Connected(req) if req.conn == conn && req.cid == cid => { - return Poll::Ready(Ok((idx, req.cid))); + ChannelState::Connected if storage.conn == conn.raw() && storage.cid == cid => { + return Poll::Ready(Ok(())); } _ => {} } @@ -397,19 +281,16 @@ impl< }) .await?; - // info!("Peer setup cid {} Sending initial credits", state.peer_cid); - // Send initial credits let next_req_id = self.next_request_id(); let req = controller .signal(conn, next_req_id, &LeCreditFlowInd { cid, credits }, &mut tx[..]) .await?; - - // info!("Done!"); Ok(cid) } - pub async fn dispatch(&self, header: L2capHeader, packet: Packet<'d>) -> Result<(), Error> { + /// Dispatch an incoming L2CAP packet to the appropriate channel. + pub(crate) async fn dispatch(&self, header: L2capHeader, packet: Packet<'d>) -> Result<(), Error> { if header.channel < BASE_ID { return Err(Error::InvalidChannelId); } @@ -418,17 +299,17 @@ impl< if chan > self.inbound.len() { return Err(Error::InvalidChannelId); } + trace!("[l2cap] inbound data packet for {}", header.channel); self.state.lock(|state| { let mut state = state.borrow_mut(); for (idx, storage) in state.channels.iter_mut().enumerate() { - match storage { - ChannelState::Connected(state) if header.channel == state.cid => { - if state.flow_control.available() == 0 { - // info!("No credits available on channel {}", state.cid); + match storage.state { + ChannelState::Connected if header.channel == storage.cid => { + if storage.flow_control.available() == 0 { return Err(Error::OutOfMemory); } - state.flow_control.received(1); + storage.flow_control.received(1); } _ => {} } @@ -442,52 +323,22 @@ impl< Ok(()) } - pub async fn control(&self, conn: ConnHandle, data: &[u8]) -> Result<(), Error> { - // info!("Inbound signal: {:?}", signal); + /// Handle incoming L2CAP signal + pub(crate) async fn signal(&self, conn: ConnHandle, data: &[u8]) -> Result<(), Error> { let (header, data) = L2capSignalHeader::from_hci_bytes(data)?; + trace!("[l2cap] inbound signal code {:?}", header.code); match header.code { L2capSignalCode::LeCreditConnReq => { let req = LeCreditConnReq::from_hci_bytes_complete(data)?; - self.peer_connect(|i, c| PeerConnectingState { - conn, - cid: c, - psm: req.psm, - request_id: header.identifier, - peer_cid: req.scid, - offered_credits: req.credits, - mps: req.mps, - mtu: req.mtu, - })?; - Ok(()) + self.handle_connect_request(conn, header.identifier, &req) } L2capSignalCode::LeCreditConnRes => { let res = LeCreditConnRes::from_hci_bytes_complete(data)?; - // info!("Got response to create request: {:?}", res); - match res.result { - LeCreditConnResultCode::Success => { - // Must be a response of a previous request which should already by allocated a channel for - self.connected(header.identifier, |idx, req| ConnectedState { - conn: req.conn, - cid: req.cid, - psm: req.psm, - flow_control: CreditFlowControl::new(req.flow_control_policy, req.initial_credits), - peer_credits: res.credits, - peer_cid: res.dcid, - pool_id: AllocId::dynamic(idx), - mps: req.mps.min(res.mps), - mtu: req.mtu.min(res.mtu), - })?; - Ok(()) - } - other => { - warn!("Channel open request failed: {:?}", other); - Err(Error::NotSupported) - } - } + self.handle_connect_response(conn, header.identifier, &res) } L2capSignalCode::LeCreditFlowInd => { let req = LeCreditFlowInd::from_hci_bytes_complete(data)?; - self.remote_credits(req.cid, req.credits)?; + self.handle_credit_flow(&req)?; Ok(()) } L2capSignalCode::CommandRejectRes => { @@ -497,31 +348,72 @@ impl< } L2capSignalCode::DisconnectionReq => { let req = DisconnectionReq::from_hci_bytes_complete(data)?; - info!("Disconnect request: {:?}!", req); self.disconnect(req.dcid)?; Ok(()) } L2capSignalCode::DisconnectionRes => { let res = DisconnectionRes::from_hci_bytes_complete(data)?; - warn!("Disconnection result!"); - self.disconnected(res.dcid)?; - Ok(()) + self.handle_disconnect_response(&res) } _ => Err(Error::NotSupported), } } - fn with_connected_channel R, R>( - &self, - cid: u16, - f: F, - ) -> Result { + fn handle_connect_request(&self, conn: ConnHandle, identifier: u8, req: &LeCreditConnReq) -> Result<(), Error> { + self.alloc(|storage| { + storage.conn = conn.raw(); + storage.psm = req.psm; + storage.peer_cid = req.scid; + storage.peer_credits = req.credits; + storage.mps = req.mps; + storage.mtu = req.mtu; + storage.state = ChannelState::PeerConnecting(identifier); + })?; + self.state.lock(|state| { + state.borrow_mut().accept_waker.wake(); + }); + Ok(()) + } + + fn handle_connect_response(&self, conn: ConnHandle, identifier: u8, res: &LeCreditConnRes) -> Result<(), Error> { + match res.result { + LeCreditConnResultCode::Success => { + // Must be a response of a previous request which should already by allocated a channel for + self.state.lock(|state| { + let mut state = state.borrow_mut(); + for storage in state.channels.iter_mut() { + match storage.state { + ChannelState::Connecting(req_id) if identifier == req_id && conn.raw() == storage.conn => { + storage.peer_cid = res.dcid; + storage.peer_credits = res.credits; + storage.mps = storage.mps.min(res.mps); + storage.mtu = storage.mtu.min(res.mtu); + storage.state = ChannelState::Connected; + state.create_waker.wake(); + return Ok(()); + } + _ => {} + } + } + Err(Error::NotFound) + }) + } + other => { + warn!("Channel open request failed: {:?}", other); + Err(Error::NotSupported) + } + } + } + + fn handle_credit_flow(&self, req: &LeCreditFlowInd) -> Result<(), Error> { self.state.lock(|state| { let mut state = state.borrow_mut(); - for (idx, chan) in state.channels.iter_mut().enumerate() { - match chan { - ChannelState::Connected(state) if state.cid == cid => { - return Ok(f(idx, state)); + for (idx, storage) in state.channels.iter_mut().enumerate() { + match storage.state { + ChannelState::Connected if storage.peer_cid == req.cid => { + storage.peer_credits += req.credits; + state.credit_wakers[idx].wake(); + return Ok(()); } _ => {} } @@ -530,50 +422,65 @@ impl< }) } - async fn receive_pdu(&self, cid: u16, idx: usize) -> Result, Error> { - match self.inbound[idx].receive().await { - Some(pdu) => Ok(pdu), - None => { - self.confirm_disconnected(cid)?; - Err(Error::ChannelClosed) + fn handle_disconnect_response(&self, res: &DisconnectionRes) -> Result<(), Error> { + let cid = res.dcid; + self.state.lock(|state| { + let mut state = state.borrow_mut(); + for storage in state.channels.iter_mut() { + match storage.state { + ChannelState::Disconnecting if cid == storage.cid => { + storage.state = ChannelState::Disconnected; + break; + } + ChannelState::PeerConnecting(_) if cid == storage.cid => { + storage.state = ChannelState::Disconnecting; + break; + } + ChannelState::Connecting(_) if cid == storage.cid => { + storage.state = ChannelState::Disconnecting; + break; + } + ChannelState::Connected if cid == storage.cid => { + storage.state = ChannelState::Disconnecting; + break; + } + _ => {} + } } - } + Ok(()) + }) } + /// Receive data on a given channel and copy it into the buffer. + /// + /// The length provided buffer slice must be equal or greater to the agreed MTU. pub(crate) async fn receive( &self, cid: u16, buf: &mut [u8], hci: &HciController<'_, T>, ) -> Result> { - let idx = self.with_connected_channel(cid, |idx, _state| idx)?; + let idx = self.connected_channel_index(cid)?; + let mut n_received = 1; - let packet = self.receive_pdu(cid, idx).await?; + let packet = self.receive_pdu(cid, idx, hci).await?; let len = packet.len; let mut r = ReadCursor::new(packet.as_ref()); let remaining: u16 = r.read()?; - // info!("Total expected: {}", remaining); let data = r.remaining(); let to_copy = data.len().min(buf.len()); buf[..to_copy].copy_from_slice(&data[..to_copy]); let mut pos = to_copy; - // info!("Received {} bytes so far", pos); - let mut remaining = remaining as usize - data.len(); self.flow_control(cid, hci, packet.packet).await?; - //info!( - // "Total size of PDU is {}, read buffer size is {} remaining; {}", - // len, - // buf.len(), - // remaining - //); + // We have some k-frames to reassemble while remaining > 0 { - let packet = self.receive_pdu(cid, idx).await?; + let packet = self.receive_pdu(cid, idx, hci).await?; n_received += 1; let to_copy = packet.len.min(buf.len() - pos); if to_copy > 0 { @@ -584,10 +491,42 @@ impl< self.flow_control(cid, hci, packet.packet).await?; } - // info!("Total reserved {} bytes", pos); Ok(pos) } + // Return the array index for a given active channel + fn connected_channel_index(&self, cid: u16) -> Result { + self.state.lock(|state| { + let state = state.borrow(); + for (idx, chan) in state.channels.iter().enumerate() { + if chan.cid == cid && chan.state == ChannelState::Connected { + return Ok(idx); + } + } + Err(Error::NotFound) + }) + } + + async fn receive_pdu( + &self, + cid: u16, + idx: usize, + hci: &HciController<'_, T>, + ) -> Result, AdapterError> { + match self.inbound[idx].receive().await { + Some(pdu) => Ok(pdu), + None => { + self.confirm_disconnected(cid, hci).await?; + Err(Error::ChannelClosed.into()) + } + } + } + + /// Send the provided buffer over a given l2cap channel. + /// + /// The buffer will be segmented to the maximum payload size agreed in the opening handshake. + /// + /// If the channel has been closed or the channel id is not valid, an error is returned. pub(crate) async fn send( &self, cid: u16, @@ -595,11 +534,9 @@ impl< hci: &HciController<'_, T>, ) -> Result<(), AdapterError> { let mut p_buf = [0u8; L2CAP_MTU]; - let (conn, mps, peer_cid) = - self.with_connected_channel(cid, |_, state| (state.conn, state.mps, state.peer_cid))?; + let (conn, mps, peer_cid) = self.connected_channel_params(cid)?; // The number of packets we'll need to send for this payload let n_packets = 1 + ((buf.len() as u16).saturating_sub(mps - 2)).div_ceil(mps); - // info!("Sending data of len {} into {} packets", buf.len(), n_packets); poll_fn(|cx| self.poll_request_to_send(cid, n_packets, Some(cx))).await?; @@ -620,6 +557,11 @@ impl< Ok(()) } + /// Send the provided buffer over a given l2cap channel. + /// + /// The buffer will be segmented to the maximum payload size agreed in the opening handshake. + /// + /// If the channel has been closed or the channel id is not valid, an error is returned. pub(crate) fn try_send( &self, cid: u16, @@ -627,11 +569,10 @@ impl< hci: &HciController<'_, T>, ) -> Result<(), AdapterError> { let mut p_buf = [0u8; L2CAP_MTU]; - let (conn, mps, peer_cid) = - self.with_connected_channel(cid, |_, state| (state.conn, state.mps, state.peer_cid))?; + let (conn, mps, peer_cid) = self.connected_channel_params(cid)?; + // The number of packets we'll need to send for this payload let n_packets = 1 + ((buf.len() as u16).saturating_sub(mps - 2)).div_ceil(mps); - // info!("Sending data of len {} into {} packets", buf.len(), n_packets); match self.poll_request_to_send(cid, n_packets, None) { Poll::Ready(res) => res?, @@ -658,6 +599,23 @@ impl< Ok(()) } + fn connected_channel_params(&self, cid: u16) -> Result<(ConnHandle, u16, u16), Error> { + self.state.lock(|state| { + let state = state.borrow(); + for chan in state.channels.iter() { + match chan.state { + ChannelState::Connected if chan.cid == cid => { + return Ok((ConnHandle::new(chan.conn), chan.mps, chan.peer_cid)); + } + _ => {} + } + } + Err(Error::NotFound) + }) + } + + // Check the current state of flow control and send flow indications if + // our policy says so. async fn flow_control( &self, cid: u16, @@ -666,10 +624,10 @@ impl< ) -> Result<(), AdapterError> { let (conn, credits) = self.state.lock(|state| { let mut state = state.borrow_mut(); - for (idx, storage) in state.channels.iter_mut().enumerate() { - match storage { - ChannelState::Connected(state) if cid == state.cid => { - return Ok((state.conn, state.flow_control.process())); + for storage in state.channels.iter_mut() { + match storage.state { + ChannelState::Connected if cid == storage.cid => { + return Ok((storage.conn, storage.flow_control.process())); } _ => {} } @@ -682,35 +640,57 @@ impl< let signal = LeCreditFlowInd { cid, credits }; // Reuse packet buffer for signalling data to save the extra TX buffer - hci.signal(conn, identifier, &signal, packet.as_mut()).await?; + hci.signal(ConnHandle::new(conn), identifier, &signal, packet.as_mut()) + .await?; } Ok(()) } - fn confirm_disconnected(&self, cid: u16) -> Result<(), Error> { - self.state.lock(|state| { + async fn confirm_disconnected( + &self, + cid: u16, + hci: &HciController<'_, T>, + ) -> Result<(), AdapterError> { + let (handle, dcid, scid) = self.state.lock(|state| { let mut state = state.borrow_mut(); for storage in state.channels.iter_mut() { - match storage { - ChannelState::Disconnecting(state) if cid == state.cid => { - *storage = ChannelState::Disconnected; - return Ok(()); + match storage.state { + ChannelState::Disconnecting if cid == storage.cid => { + storage.state = ChannelState::Disconnected; + let scid = storage.cid; + let dcid = storage.peer_cid; + let handle = storage.conn; + storage.cid = 0; + storage.peer_cid = 0; + storage.conn = 0; + return Ok((handle, dcid, scid)); } _ => {} } } Err(Error::NotFound) - }) + })?; + + let identifier = self.next_request_id(); + let mut tx = [0; 18]; + hci.signal( + ConnHandle::new(handle), + identifier, + &DisconnectionRes { dcid, scid }, + &mut tx[..], + ) + .await?; + Ok(()) } fn poll_request_to_send(&self, cid: u16, credits: u16, cx: Option<&mut Context<'_>>) -> Poll> { self.state.lock(|state| { let mut state = state.borrow_mut(); for (idx, storage) in state.channels.iter_mut().enumerate() { - match storage { - ChannelState::Connected(s) if cid == s.cid => { - if credits <= s.peer_credits { - s.peer_credits -= credits; + match storage.state { + ChannelState::Connected if cid == storage.cid => { + if credits <= storage.peer_credits { + storage.peer_credits -= credits; return Poll::Ready(Ok(())); } else { if let Some(cx) = cx { @@ -744,12 +724,41 @@ fn encode(data: &[u8], packet: &mut [u8], peer_cid: u16, header: Option) -> Ok(w.len()) } +pub struct ChannelStorage { + state: ChannelState, + conn: u16, + cid: u16, + psm: u16, + mps: u16, + mtu: u16, + flow_control: CreditFlowControl, + + peer_cid: u16, + peer_credits: u16, +} + +impl ChannelStorage { + const DISCONNECTED: ChannelStorage = ChannelStorage { + state: ChannelState::Disconnected, + conn: 0, + cid: 0, + mps: 0, + mtu: 0, + psm: 0, + + flow_control: CreditFlowControl::new(CreditFlowPolicy::Every(1), 0), + peer_cid: 0, + peer_credits: 0, + }; +} + +#[derive(PartialEq)] pub enum ChannelState { Disconnected, - Connecting(ConnectingState), - PeerConnecting(PeerConnectingState), - Connected(ConnectedState), - Disconnecting(DisconnectingState), + Connecting(u8), + PeerConnecting(u8), + Connected, + Disconnecting, } /// Control how credits are issued by the receiving end. @@ -768,7 +777,7 @@ impl Default for CreditFlowPolicy { } } -#[derive(Clone, Debug)] +#[derive(Debug)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub(crate) struct CreditFlowControl { policy: CreditFlowPolicy, @@ -777,14 +786,13 @@ pub(crate) struct CreditFlowControl { } impl CreditFlowControl { - fn new(policy: CreditFlowPolicy, initial_credits: u16) -> Self { + const fn new(policy: CreditFlowPolicy, initial_credits: u16) -> Self { Self { policy, credits: initial_credits, received: 0, } } - fn available(&self) -> u16 { self.credits } @@ -819,54 +827,3 @@ impl CreditFlowControl { } } } - -#[derive(Debug)] -#[cfg_attr(feature = "defmt", derive(defmt::Format))] -pub struct ConnectingState { - pub(crate) conn: ConnHandle, - pub(crate) cid: u16, - pub(crate) request_id: u8, - pub(crate) flow_control_policy: CreditFlowPolicy, - - pub(crate) psm: u16, - pub(crate) initial_credits: u16, - pub(crate) mps: u16, - pub(crate) mtu: u16, -} - -#[derive(Debug)] -#[cfg_attr(feature = "defmt", derive(defmt::Format))] -pub struct PeerConnectingState { - pub(crate) conn: ConnHandle, - pub(crate) cid: u16, - pub(crate) request_id: u8, - - pub(crate) psm: u16, - pub(crate) peer_cid: u16, - pub(crate) offered_credits: u16, - pub(crate) mps: u16, - pub(crate) mtu: u16, -} - -#[derive(Debug, Clone)] -#[cfg_attr(feature = "defmt", derive(defmt::Format))] -pub struct ConnectedState { - pub(crate) conn: ConnHandle, - pub(crate) cid: u16, - pub(crate) psm: u16, - pub(crate) mps: u16, - pub(crate) mtu: u16, - pub(crate) flow_control: CreditFlowControl, - - pub(crate) peer_cid: u16, - pub(crate) peer_credits: u16, - - pub(crate) pool_id: AllocId, -} - -#[derive(Debug)] -#[cfg_attr(feature = "defmt", derive(defmt::Format))] -pub struct DisconnectingState { - pub(crate) conn: ConnHandle, - pub(crate) cid: u16, -} diff --git a/host/src/connection.rs b/host/src/connection.rs index efb37b6c..1b21c19e 100644 --- a/host/src/connection.rs +++ b/host/src/connection.rs @@ -7,7 +7,6 @@ use embassy_sync::blocking_mutex::raw::RawMutex; use embassy_time::Duration; use crate::adapter::Adapter; -pub use crate::connection_manager::ConnectionInfo; use crate::scan::ScanConfig; use crate::AdapterError; @@ -50,7 +49,7 @@ impl Connection { self.handle } - pub async fn disconnect< + pub fn disconnect< M: RawMutex, T: Controller + ControllerCmdSync, const CONNS: usize, @@ -62,9 +61,7 @@ impl Connection { &mut self, adapter: &Adapter<'_, M, T, CONNS, CHANNELS, L2CAP_MTU, L2CAP_TXQ, L2CAP_RXQ>, ) -> Result<(), AdapterError> { - adapter - .command(Disconnect::new(self.handle, DisconnectReason::RemoteUserTerminatedConn)) - .await?; + adapter.try_command(Disconnect::new(self.handle, DisconnectReason::RemoteUserTerminatedConn))?; Ok(()) } @@ -80,7 +77,7 @@ impl Connection { &self, adapter: &Adapter<'_, M, T, CONNS, CHANNELS, L2CAP_MTU, L2CAP_TXQ, L2CAP_RXQ>, ) -> Result> { - let role = adapter.connections.with_connection(self.handle, |info| info.role)?; + let role = adapter.connections.role(self.handle)?; Ok(role) } @@ -96,10 +93,8 @@ impl Connection { &self, adapter: &Adapter<'_, M, T, CONNS, CHANNELS, L2CAP_MTU, L2CAP_TXQ, L2CAP_RXQ>, ) -> Result> { - let role = adapter - .connections - .with_connection(self.handle, |info| info.peer_address)?; - Ok(role) + let addr = adapter.connections.peer_address(self.handle)?; + Ok(addr) } pub async fn rssi< diff --git a/host/src/connection_manager.rs b/host/src/connection_manager.rs index b37e30c8..99e723d4 100644 --- a/host/src/connection_manager.rs +++ b/host/src/connection_manager.rs @@ -2,7 +2,8 @@ use core::cell::RefCell; use core::future::poll_fn; use core::task::{Context, Poll}; -use bt_hci::param::{AddrKind, BdAddr, ConnHandle, LeConnRole, Status}; +use bt_hci::event::le::LeConnectionComplete; +use bt_hci::param::{AddrKind, BdAddr, ConnHandle, LeConnRole}; use embassy_sync::blocking_mutex::raw::RawMutex; use embassy_sync::blocking_mutex::Mutex; use embassy_sync::signal::Signal; @@ -11,68 +12,78 @@ use embassy_sync::waitqueue::WakerRegistration; use crate::Error; struct State { - connections: [ConnectionState; CONNS], + connections: [ConnectionStorage; CONNS], waker: WakerRegistration, } -pub struct ConnectionManager { +pub(crate) struct ConnectionManager { state: Mutex>>, canceled: Signal, } impl ConnectionManager { - const DISCONNECTED: ConnectionState = ConnectionState::Disconnected; - pub fn new() -> Self { + pub(crate) fn new() -> Self { Self { state: Mutex::new(RefCell::new(State { - connections: [Self::DISCONNECTED; CONNS], + connections: [ConnectionStorage::DISCONNECTED; CONNS], waker: WakerRegistration::new(), })), canceled: Signal::new(), } } - pub fn disconnect(&self, h: ConnHandle) -> Result<(), Error> { + pub(crate) fn role(&self, h: ConnHandle) -> Result { self.state.lock(|state| { - let mut state = state.borrow_mut(); - for storage in state.connections.iter_mut() { - match storage { - ConnectionState::Connecting(handle, _) if *handle == h => { - *storage = ConnectionState::Disconnected; - } - ConnectionState::Connected(handle, _) if *handle == h => { - *storage = ConnectionState::Disconnected; - } - _ => {} + let state = state.borrow(); + for storage in state.connections.iter() { + if storage.state == ConnectionState::Connected && storage.handle.unwrap() == h { + return Ok(storage.role.unwrap()); } } - Ok(()) + Err(Error::NotFound) }) } - pub(crate) fn with_connection R, R>( - &self, - handle: ConnHandle, - f: F, - ) -> Result { + pub(crate) fn peer_address(&self, h: ConnHandle) -> Result { self.state.lock(|state| { let state = state.borrow(); for storage in state.connections.iter() { - match storage { - ConnectionState::Connected(h, info) if *h == handle => return Ok(f(info)), - _ => {} + if storage.state == ConnectionState::Connected && storage.handle.unwrap() == h { + return Ok(storage.peer_addr.unwrap()); } } Err(Error::NotFound) }) } - pub fn connect(&self, handle: ConnHandle, info: ConnectionInfo) -> Result<(), Error> { + pub(crate) fn disconnect(&self, h: ConnHandle) -> Result<(), Error> { self.state.lock(|state| { let mut state = state.borrow_mut(); for storage in state.connections.iter_mut() { - if let ConnectionState::Disconnected = storage { - *storage = ConnectionState::Connecting(handle, Some(info)); + match storage.state { + ConnectionState::Connecting if storage.handle.unwrap() == h => { + storage.state = ConnectionState::Disconnected; + } + ConnectionState::Connected if storage.handle.unwrap() == h => { + storage.state = ConnectionState::Disconnected; + } + _ => {} + } + } + Ok(()) + }) + } + + pub(crate) fn connect(&self, handle: ConnHandle, info: &LeConnectionComplete) -> Result<(), Error> { + self.state.lock(|state| { + let mut state = state.borrow_mut(); + for storage in state.connections.iter_mut() { + if let ConnectionState::Disconnected = storage.state { + storage.state = ConnectionState::Connecting; + storage.handle.replace(handle); + storage.peer_addr_kind.replace(info.peer_addr_kind); + storage.peer_addr.replace(info.peer_addr); + storage.role.replace(info.role); state.waker.wake(); return Ok(()); } @@ -81,32 +92,30 @@ impl ConnectionManager { }) } - pub async fn wait_canceled(&self) { + pub(crate) async fn wait_canceled(&self) { self.canceled.wait().await; self.canceled.reset(); } - pub fn canceled(&self) { + pub(crate) fn canceled(&self) { self.canceled.signal(()); } - pub fn poll_accept(&self, peers: &[(AddrKind, &BdAddr)], cx: &mut Context<'_>) -> Poll { + pub(crate) fn poll_accept(&self, peers: &[(AddrKind, &BdAddr)], cx: &mut Context<'_>) -> Poll { self.state.lock(|state| { let mut state = state.borrow_mut(); for storage in state.connections.iter_mut() { - if let ConnectionState::Connecting(handle, info) = storage { - let handle = *handle; + if let ConnectionState::Connecting = storage.state { + let handle = storage.handle.unwrap(); if !peers.is_empty() { for peer in peers.iter() { - if info.as_ref().unwrap().peer_addr_kind == peer.0 - && &info.as_ref().unwrap().peer_address == peer.1 - { - *storage = ConnectionState::Connected(handle, info.take().unwrap()); + if storage.peer_addr_kind.unwrap() == peer.0 && &storage.peer_addr.unwrap() == peer.1 { + storage.state = ConnectionState::Connected; return Poll::Ready(handle); } } } else { - *storage = ConnectionState::Connected(handle, info.take().unwrap()); + storage.state = ConnectionState::Connected; return Poll::Ready(handle); } } @@ -116,17 +125,11 @@ impl ConnectionManager { }) } - pub async fn accept(&self, peers: &[(AddrKind, &BdAddr)]) -> ConnHandle { + pub(crate) async fn accept(&self, peers: &[(AddrKind, &BdAddr)]) -> ConnHandle { poll_fn(move |cx| self.poll_accept(peers, cx)).await } } -pub enum ConnectionState { - Disconnected, - Connecting(ConnHandle, Option), - Connected(ConnHandle, ConnectionInfo), -} - pub trait DynamicConnectionManager { fn get_att_mtu(&self, conn: ConnHandle) -> u16; fn exchange_att_mtu(&self, conn: ConnHandle, mtu: u16) -> u16; @@ -137,9 +140,9 @@ impl DynamicConnectionManager for ConnectionMan self.state.lock(|state| { let mut state = state.borrow_mut(); for storage in state.connections.iter_mut() { - match storage { - ConnectionState::Connected(handle, info) if *handle == conn => { - return info.att_mtu; + match storage.state { + ConnectionState::Connected if storage.handle.unwrap() == conn => { + return storage.att_mtu; } _ => {} } @@ -151,10 +154,10 @@ impl DynamicConnectionManager for ConnectionMan self.state.lock(|state| { let mut state = state.borrow_mut(); for storage in state.connections.iter_mut() { - match storage { - ConnectionState::Connected(handle, info) if *handle == conn => { - info.att_mtu = info.att_mtu.min(mtu); - return info.att_mtu; + match storage.state { + ConnectionState::Connected if storage.handle.unwrap() == conn => { + storage.att_mtu = storage.att_mtu.min(mtu); + return storage.att_mtu; } _ => {} } @@ -166,14 +169,30 @@ impl DynamicConnectionManager for ConnectionMan #[derive(Debug)] #[cfg_attr(feature = "defmt", derive(defmt::Format))] -pub struct ConnectionInfo { - pub handle: ConnHandle, - pub status: Status, - pub role: LeConnRole, - pub peer_addr_kind: AddrKind, - pub peer_address: BdAddr, - pub interval: u16, - pub latency: u16, - pub timeout: u16, +pub struct ConnectionStorage { + pub state: ConnectionState, + pub handle: Option, + pub role: Option, + pub peer_addr_kind: Option, + pub peer_addr: Option, pub att_mtu: u16, } + +impl ConnectionStorage { + const DISCONNECTED: ConnectionStorage = ConnectionStorage { + state: ConnectionState::Disconnected, + handle: None, + role: None, + peer_addr_kind: None, + peer_addr: None, + att_mtu: 23, + }; +} + +#[derive(Debug, PartialEq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum ConnectionState { + Disconnected, + Connecting, + Connected, +} diff --git a/host/src/l2cap.rs b/host/src/l2cap.rs index 43cde874..4bceb3ee 100644 --- a/host/src/l2cap.rs +++ b/host/src/l2cap.rs @@ -1,6 +1,6 @@ use bt_hci::cmd::link_control::Disconnect; use bt_hci::controller::{Controller, ControllerCmdSync}; -use bt_hci::param::{ConnHandle, DisconnectReason}; +use bt_hci::param::DisconnectReason; use embassy_sync::blocking_mutex::raw::RawMutex; use crate::adapter::Adapter; @@ -13,7 +13,6 @@ pub(crate) mod sar; /// Handle representing an L2CAP channel. #[derive(Clone)] pub struct L2capChannel { - handle: ConnHandle, cid: u16, } @@ -39,6 +38,12 @@ impl Default for L2capChannelConfig { } impl L2capChannel { + /// Send the provided buffer over this l2cap channel. + /// + /// The buffer will be segmented to the maximum payload size agreed in the opening handshake. + /// + /// If the channel has been closed or the channel id is not valid, an error is returned. + /// If there are no available credits to send, waits until more credits are available. pub async fn send< M: RawMutex, T: Controller, @@ -55,6 +60,12 @@ impl L2capChannel { adapter.channels.send(self.cid, buf, &adapter.hci()).await } + /// Send the provided buffer over this l2cap channel. + /// + /// The buffer will be segmented to the maximum payload size agreed in the opening handshake. + /// + /// If the channel has been closed or the channel id is not valid, an error is returned. + /// If there are no available credits to send, returns Error::Busy. pub fn try_send< M: RawMutex, T: Controller, @@ -71,6 +82,9 @@ impl L2capChannel { adapter.channels.try_send(self.cid, buf, &adapter.hci()) } + /// Receive data on this channel and copy it into the buffer. + /// + /// The length provided buffer slice must be equal or greater to the agreed MTU. pub async fn receive< M: RawMutex, T: Controller, @@ -87,6 +101,7 @@ impl L2capChannel { adapter.channels.receive(self.cid, buf, &adapter.hci()).await } + /// Await an incoming connection request matching the list of PSM. pub async fn accept< M: RawMutex, T: Controller, @@ -114,9 +129,10 @@ impl L2capChannel { ) .await?; - Ok(Self { cid, handle }) + Ok(Self { cid }) } + /// Disconnect this channel. pub fn disconnect< M: RawMutex, T: Controller + ControllerCmdSync, @@ -130,13 +146,14 @@ impl L2capChannel { adapter: &Adapter<'_, M, T, CONNS, CHANNELS, L2CAP_MTU, L2CAP_TXQ, L2CAP_RXQ>, close_connection: bool, ) -> Result<(), AdapterError> { - adapter.channels.disconnect(self.cid)?; + let handle = adapter.channels.disconnect(self.cid)?; if close_connection { - adapter.try_command(Disconnect::new(self.handle, DisconnectReason::RemoteUserTerminatedConn))?; + adapter.try_command(Disconnect::new(handle, DisconnectReason::RemoteUserTerminatedConn))?; } Ok(()) } + /// Create a new connection request with the provided PSM. pub async fn create< M: RawMutex, T: Controller, @@ -165,6 +182,6 @@ where { ) .await?; - Ok(Self { handle, cid }) + Ok(Self { cid }) } } diff --git a/host/src/types/l2cap.rs b/host/src/types/l2cap.rs index 7ba9452d..6942ad3f 100644 --- a/host/src/types/l2cap.rs +++ b/host/src/types/l2cap.rs @@ -219,3 +219,9 @@ unsafe impl FixedSizeValue for DisconnectionRes { return true; } } + +impl L2capSignal for DisconnectionRes { + fn code() -> L2capSignalCode { + L2capSignalCode::DisconnectionRes + } +}