diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 557c0ba4..0a7a32e9 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -15,6 +15,7 @@ jobs: - uses: actions/checkout@v4 - name: Add dependencies run: | + sudo apt-get update sudo apt-get install --no-install-recommends libudev-dev - name: Set up cargo cache uses: actions/cache@v3 diff --git a/.github/workflows/release_trouble_host.yaml b/.github/workflows/release_trouble_host.yaml index 0bf6eebb..26be513f 100644 --- a/.github/workflows/release_trouble_host.yaml +++ b/.github/workflows/release_trouble_host.yaml @@ -16,6 +16,7 @@ jobs: - uses: actions/checkout@v4 - name: Add dependencies run: | + sudo apt-get update sudo apt-get install --no-install-recommends libudev-dev - name: Set up cargo cache uses: actions/cache@v3 diff --git a/.github/workflows/release_trouble_host_macros.yaml b/.github/workflows/release_trouble_host_macros.yaml index 2080fcbb..7691aeab 100644 --- a/.github/workflows/release_trouble_host_macros.yaml +++ b/.github/workflows/release_trouble_host_macros.yaml @@ -16,6 +16,7 @@ jobs: - uses: actions/checkout@v4 - name: Add dependencies run: | + sudo apt-get update sudo apt-get install --no-install-recommends libudev-dev - name: Set up cargo cache uses: actions/cache@v3 diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 0f6e3cbb..17cba92d 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -20,6 +20,7 @@ jobs: gh pr checkout "${{ github.event.inputs.prNr }}" - name: Add dependencies run: | + sudo apt-get update sudo apt-get install --no-install-recommends libudev-dev - name: Set up cargo cache uses: actions/cache@v3 diff --git a/host/src/att.rs b/host/src/att.rs index fc11d083..120c1d06 100644 --- a/host/src/att.rs +++ b/host/src/att.rs @@ -1,3 +1,4 @@ +//! Attribute Protocol (ATT) PDU definitions use core::fmt::Display; use core::mem; @@ -30,6 +31,8 @@ pub(crate) const ATT_READ_MULTIPLE_RSP: u8 = 0x21; pub(crate) const ATT_READ_BLOB_REQ: u8 = 0x0c; pub(crate) const ATT_READ_BLOB_RSP: u8 = 0x0d; pub(crate) const ATT_HANDLE_VALUE_NTF: u8 = 0x1b; +pub(crate) const ATT_HANDLE_VALUE_IND: u8 = 0x1d; +pub(crate) const ATT_HANDLE_VALUE_CMF: u8 = 0x1e; /// Attribute Error Code /// @@ -155,108 +158,208 @@ impl codec::Type for AttErrorCode { } } +/// ATT Client PDU (Request, Command, Confirmation) +/// +/// The ATT Client PDU is used to send requests, commands and confirmations to the ATT Server +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +#[derive(Debug)] +pub enum AttClient<'d> { + /// ATT Request PDU + Request(AttReq<'d>), + /// ATT Command PDU + Command(AttCmd<'d>), + /// ATT Confirmation PDU + Confirmation(AttCfm), +} + +/// ATT Request PDU #[cfg_attr(feature = "defmt", derive(defmt::Format))] #[derive(Debug)] pub enum AttReq<'d> { + /// Read By Group Type Request ReadByGroupType { + /// Start attribute handle start: u16, + /// End attribute handle end: u16, + /// Group type group_type: Uuid, }, + /// Read By Type Request ReadByType { + /// Start attribute handle start: u16, + /// End attribute handle end: u16, + /// Attribute type attribute_type: Uuid, }, + /// Read Request Read { + /// Attribute handle handle: u16, }, + /// Write Request Write { + /// Attribute handle handle: u16, + /// Attribute value data: &'d [u8], }, - WriteCmd { - handle: u16, - data: &'d [u8], - }, + /// Exchange MTU Request ExchangeMtu { + /// Client MTU mtu: u16, }, + /// Find By Type Value Request FindByTypeValue { + /// Start attribute handle start_handle: u16, + /// End attribute handle end_handle: u16, + /// Attribute type att_type: u16, + /// Attribute value att_value: &'d [u8], }, + /// Find Information Request FindInformation { + /// Start attribute handle start_handle: u16, + /// End attribute handle end_handle: u16, }, + /// Prepare Write Request PrepareWrite { + /// Attribute handle handle: u16, + /// Attribute offset offset: u16, + /// Attribute value value: &'d [u8], }, + /// Execute Write Request ExecuteWrite { + /// Flags flags: u8, }, + /// Read Multiple Request ReadMultiple { + /// Attribute handles handles: &'d [u8], }, + /// Read Blob Request ReadBlob { + /// Attribute handle handle: u16, + /// Attribute offset offset: u16, }, } +/// ATT Command PDU +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +#[derive(Debug)] +pub enum AttCmd<'d> { + /// Write Command + Write { + /// Attribute handle + handle: u16, + /// Attribute value + data: &'d [u8], + }, +} + +/// ATT Confirmation PDU +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +#[derive(Debug)] +pub enum AttCfm { + /// Confirm Indication + ConfirmIndication, +} + +/// ATT Server PDU (Response, Unsolicited) +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +#[derive(Debug)] +pub enum AttServer<'d> { + /// ATT Response PDU + Response(AttRsp<'d>), + /// ATT Unsolicited PDU + Unsolicited(AttUns<'d>), +} + +/// ATT Response PDU #[cfg_attr(feature = "defmt", derive(defmt::Format))] #[derive(Debug)] pub enum AttRsp<'d> { + /// Exchange MTU Response ExchangeMtu { + /// Server MTU mtu: u16, }, + /// Find By Type Value Response FindByTypeValue { + /// Iterator over the found handles it: FindByTypeValueIter<'d>, }, + /// Error Response Error { + /// Request opcode request: u8, + /// Attribute handle handle: u16, + /// Error code code: AttErrorCode, }, + /// Read Response ReadByType { + /// Iterator over the found handles it: ReadByTypeIter<'d>, }, + /// Read Response Read { + /// Attribute value data: &'d [u8], }, + /// Write Response Write, } +/// ATT Unsolicited PDU #[cfg_attr(feature = "defmt", derive(defmt::Format))] #[derive(Debug)] -pub enum Att<'d> { - Req(AttReq<'d>), - Rsp(AttRsp<'d>), -} - -impl codec::Type for AttRsp<'_> { - fn size(&self) -> usize { - AttRsp::size(self) - } -} - -impl codec::Encode for AttRsp<'_> { - fn encode(&self, dest: &mut [u8]) -> Result<(), codec::Error> { - AttRsp::encode(self, dest) - } +pub enum AttUns<'d> { + /// Notify + Notify { + /// Attribute handle + handle: u16, + /// Attribute value + data: &'d [u8], + }, + /// Indicate + Indicate { + /// Attribute handle + handle: u16, + /// Attribute value + data: &'d [u8], + }, } -impl<'d> codec::Decode<'d> for AttRsp<'d> { - fn decode(src: &'d [u8]) -> Result, codec::Error> { - AttRsp::decode(src) - } +/// ATT Protocol Data Unit (PDU) +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +#[derive(Debug)] +pub enum Att<'d> { + /// ATT Client PDU (Request, Command, Confirmation) + /// + /// The ATT Client PDU is used to send requests, commands and confirmations to the ATT Server + Client(AttClient<'d>), + /// ATT Server PDU (Response, Unsolicited) + /// + /// The ATT Server PDU is used to send responses and unsolicited ATT PDUs (notifications and indications) to the ATT Client + Server(AttServer<'d>), } +/// An Iterator-like type for iterating over the found handles #[cfg_attr(feature = "defmt", derive(defmt::Format))] #[derive(Clone, Debug)] pub struct FindByTypeValueIter<'d> { @@ -264,6 +367,8 @@ pub struct FindByTypeValueIter<'d> { } impl FindByTypeValueIter<'_> { + /// Get the next pair of start and end attribute handles + #[allow(clippy::should_implement_trait)] pub fn next(&mut self) -> Option> { if self.cursor.available() >= 4 { let res = (|| { @@ -278,6 +383,7 @@ impl FindByTypeValueIter<'_> { } } +/// An Iterator-like type for iterating over the found handles #[cfg_attr(feature = "defmt", derive(defmt::Format))] #[derive(Clone, Debug)] pub struct ReadByTypeIter<'d> { @@ -286,6 +392,8 @@ pub struct ReadByTypeIter<'d> { } impl<'d> ReadByTypeIter<'d> { + /// Get the next pair of attribute handle and attribute data + #[allow(clippy::should_implement_trait)] pub fn next(&mut self) -> Option> { if self.cursor.available() >= self.item_len { let res = (|| { @@ -300,8 +408,32 @@ impl<'d> ReadByTypeIter<'d> { } } +impl<'d> AttServer<'d> { + fn size(&self) -> usize { + match self { + Self::Response(rsp) => rsp.size(), + Self::Unsolicited(uns) => uns.size(), + } + } + + fn encode(&self, dest: &mut [u8]) -> Result<(), codec::Error> { + match self { + Self::Response(rsp) => rsp.encode(dest), + Self::Unsolicited(uns) => uns.encode(dest), + } + } + + fn decode_with_opcode(opcode: u8, r: ReadCursor<'d>) -> Result { + let decoded = match opcode { + ATT_HANDLE_VALUE_NTF | ATT_HANDLE_VALUE_IND => Self::Unsolicited(AttUns::decode_with_opcode(opcode, r)?), + _ => Self::Response(AttRsp::decode_with_opcode(opcode, r)?), + }; + Ok(decoded) + } +} + impl<'d> AttRsp<'d> { - pub fn size(&self) -> usize { + fn size(&self) -> usize { 1 + match self { Self::ExchangeMtu { mtu: u16 } => 2, Self::FindByTypeValue { it } => it.cursor.len(), @@ -312,7 +444,7 @@ impl<'d> AttRsp<'d> { } } - pub fn encode(&self, dest: &mut [u8]) -> Result<(), codec::Error> { + fn encode(&self, dest: &mut [u8]) -> Result<(), codec::Error> { let mut w = WriteCursor::new(dest); match self { Self::ExchangeMtu { mtu } => { @@ -353,13 +485,7 @@ impl<'d> AttRsp<'d> { Ok(()) } - pub fn decode(data: &'d [u8]) -> Result, codec::Error> { - let mut r = ReadCursor::new(data); - let opcode: u8 = r.read()?; - AttRsp::decode_with_opcode(opcode, r) - } - - pub fn decode_with_opcode(opcode: u8, mut r: ReadCursor<'d>) -> Result, codec::Error> { + fn decode_with_opcode(opcode: u8, mut r: ReadCursor<'d>) -> Result { match opcode { ATT_FIND_BY_TYPE_VALUE_RSP => Ok(Self::FindByTypeValue { it: FindByTypeValueIter { cursor: r }, @@ -390,46 +516,81 @@ impl<'d> AttRsp<'d> { } } -impl From for AttErrorCode { - fn from(e: codec::Error) -> Self { - AttErrorCode::INVALID_PDU - } -} - -impl codec::Type for AttReq<'_> { +impl<'d> AttUns<'d> { fn size(&self) -> usize { - AttReq::size(self) + 1 + match self { + Self::Notify { data, .. } => 2 + data.len(), + Self::Indicate { data, .. } => 2 + data.len(), + } } -} -impl codec::Encode for AttReq<'_> { fn encode(&self, dest: &mut [u8]) -> Result<(), codec::Error> { - AttReq::encode(self, dest) + let mut w = WriteCursor::new(dest); + match self { + Self::Notify { handle, data } => { + w.write(ATT_HANDLE_VALUE_NTF)?; + w.write(*handle)?; + w.append(data)?; + } + Self::Indicate { handle, data } => { + w.write(ATT_HANDLE_VALUE_IND)?; + w.write(*handle)?; + w.append(data)?; + } + } + Ok(()) } -} -impl<'d> codec::Decode<'d> for AttReq<'d> { - fn decode(data: &'d [u8]) -> Result, codec::Error> { - AttReq::decode(data) + fn decode_with_opcode(opcode: u8, mut r: ReadCursor<'d>) -> Result { + match opcode { + ATT_HANDLE_VALUE_NTF => { + let handle = r.read()?; + Ok(Self::Notify { + handle, + data: r.remaining(), + }) + } + ATT_HANDLE_VALUE_IND => { + let handle = r.read()?; + Ok(Self::Indicate { + handle, + data: r.remaining(), + }) + } + _ => Err(codec::Error::InvalidValue), + } } } -impl<'d> Att<'d> { - pub fn decode(data: &'d [u8]) -> Result, codec::Error> { - let mut r = ReadCursor::new(data); - let opcode: u8 = r.read()?; - if opcode % 2 == 0 { - let req = AttReq::decode_with_opcode(opcode, r)?; - Ok(Att::Req(req)) - } else { - let rsp = AttRsp::decode_with_opcode(opcode, r)?; - Ok(Att::Rsp(rsp)) +impl<'d> AttClient<'d> { + fn size(&self) -> usize { + match self { + Self::Request(req) => req.size(), + Self::Command(cmd) => cmd.size(), + Self::Confirmation(cfm) => cfm.size(), + } + } + + fn encode(&self, dest: &mut [u8]) -> Result<(), codec::Error> { + match self { + Self::Request(req) => req.encode(dest), + Self::Command(cmd) => cmd.encode(dest), + Self::Confirmation(cfm) => cfm.encode(dest), } } + + fn decode_with_opcode(opcode: u8, r: ReadCursor<'d>) -> Result { + let decoded = match opcode { + ATT_WRITE_CMD => Self::Command(AttCmd::decode_with_opcode(opcode, r)?), + ATT_HANDLE_VALUE_CMF => Self::Confirmation(AttCfm::decode_with_opcode(opcode, r)?), + _ => Self::Request(AttReq::decode_with_opcode(opcode, r)?), + }; + Ok(decoded) + } } impl<'d> AttReq<'d> { - pub fn size(&self) -> usize { + fn size(&self) -> usize { 1 + match self { Self::ExchangeMtu { .. } => 2, Self::FindByTypeValue { @@ -448,7 +609,7 @@ impl<'d> AttReq<'d> { _ => unimplemented!(), } } - pub fn encode(&self, dest: &mut [u8]) -> Result<(), codec::Error> { + fn encode(&self, dest: &mut [u8]) -> Result<(), codec::Error> { let mut w = WriteCursor::new(dest); match self { Self::ExchangeMtu { mtu } => { @@ -491,13 +652,7 @@ impl<'d> AttReq<'d> { Ok(()) } - pub fn decode(data: &'d [u8]) -> Result, codec::Error> { - let mut r = ReadCursor::new(data); - let opcode: u8 = r.read()?; - AttReq::decode_with_opcode(opcode, r) - } - - pub fn decode_with_opcode(opcode: u8, r: ReadCursor<'d>) -> Result, codec::Error> { + fn decode_with_opcode(opcode: u8, r: ReadCursor<'d>) -> Result { let payload = r.remaining(); match opcode { ATT_READ_BY_GROUP_TYPE_REQ => { @@ -549,12 +704,6 @@ impl<'d> AttReq<'d> { Ok(Self::Write { handle, data }) } - ATT_WRITE_CMD => { - let handle = (payload[0] as u16) + ((payload[1] as u16) << 8); - let data = &payload[2..]; - - Ok(Self::WriteCmd { handle, data }) - } ATT_EXCHANGE_MTU_REQ => { let mtu = (payload[0] as u16) + ((payload[1] as u16) << 8); Ok(Self::ExchangeMtu { mtu }) @@ -607,3 +756,121 @@ impl<'d> AttReq<'d> { } } } + +impl<'d> AttCmd<'d> { + fn size(&self) -> usize { + 1 + match self { + Self::Write { handle, data } => 2 + data.len(), + } + } + + fn encode(&self, dest: &mut [u8]) -> Result<(), codec::Error> { + let mut w = WriteCursor::new(dest); + match self { + Self::Write { handle, data } => { + w.write(ATT_WRITE_REQ)?; + w.write(*handle)?; + w.append(data)?; + } + } + Ok(()) + } + + fn decode_with_opcode(opcode: u8, r: ReadCursor<'d>) -> Result { + let payload = r.remaining(); + match opcode { + ATT_WRITE_CMD => { + let handle = (payload[0] as u16) + ((payload[1] as u16) << 8); + let data = &payload[2..]; + + Ok(Self::Write { handle, data }) + } + code => { + warn!("[att] unknown opcode {:x}", code); + Err(codec::Error::InvalidValue) + } + } + } +} + +impl AttCfm { + fn size(&self) -> usize { + 1 + } + + fn encode(&self, dest: &mut [u8]) -> Result<(), codec::Error> { + let mut w = WriteCursor::new(dest); + match self { + Self::ConfirmIndication => { + w.write(ATT_HANDLE_VALUE_CMF)?; + } + } + Ok(()) + } + + fn decode_with_opcode(opcode: u8, r: ReadCursor<'_>) -> Result { + let payload = r.remaining(); + match opcode { + ATT_HANDLE_VALUE_CMF => Ok(Self::ConfirmIndication), + code => { + warn!("[att] unknown opcode {:x}", code); + Err(codec::Error::InvalidValue) + } + } + } +} + +impl<'d> Att<'d> { + /// Get the wire-size of the ATT PDU + pub fn size(&self) -> usize { + match self { + Self::Client(client) => client.size(), + Self::Server(server) => server.size(), + } + } + + /// Encode the ATT PDU into a byte buffer + pub fn encode(&self, dest: &mut [u8]) -> Result<(), codec::Error> { + match self { + Self::Client(client) => client.encode(dest), + Self::Server(server) => server.encode(dest), + } + } + + /// Decode an ATT PDU from a byte buffer + pub fn decode(data: &'d [u8]) -> Result, codec::Error> { + let mut r = ReadCursor::new(data); + let opcode: u8 = r.read()?; + if opcode % 2 == 0 { + let client = AttClient::decode_with_opcode(opcode, r)?; + Ok(Att::Client(client)) + } else { + let server = AttServer::decode_with_opcode(opcode, r)?; + Ok(Att::Server(server)) + } + } +} + +impl From for AttErrorCode { + fn from(e: codec::Error) -> Self { + AttErrorCode::INVALID_PDU + } +} + +impl codec::Type for Att<'_> { + fn size(&self) -> usize { + Self::size(self) + } +} + +impl codec::Encode for Att<'_> { + fn encode(&self, dest: &mut [u8]) -> Result<(), codec::Error> { + Self::encode(self, dest) + } +} + +impl<'d> codec::Decode<'d> for Att<'d> { + fn decode(data: &'d [u8]) -> Result { + Self::decode(data) + } +} diff --git a/host/src/attribute.rs b/host/src/attribute.rs index 2773ab7d..7db769a1 100644 --- a/host/src/attribute.rs +++ b/host/src/attribute.rs @@ -612,7 +612,8 @@ impl Drop for ServiceBuilder<'_, '_, M, MAX> { #[cfg_attr(feature = "defmt", derive(defmt::Format))] #[derive(Clone, Copy, Debug, PartialEq)] pub struct Characteristic { - pub(crate) cccd_handle: Option, + /// Handle value assigned to the Client Characteristic Configuration Descriptor (if any) + pub cccd_handle: Option, /// Handle value assigned to this characteristic when it is added to the Gatt Attribute Table pub handle: u16, pub(crate) phantom: PhantomData, diff --git a/host/src/attribute_server.rs b/host/src/attribute_server.rs index 4dab556b..92342148 100644 --- a/host/src/attribute_server.rs +++ b/host/src/attribute_server.rs @@ -4,7 +4,7 @@ use bt_hci::param::ConnHandle; use embassy_sync::blocking_mutex::raw::RawMutex; use embassy_sync::blocking_mutex::Mutex; -use crate::att::{self, AttErrorCode, AttReq}; +use crate::att::{self, AttClient, AttCmd, AttErrorCode, AttReq}; use crate::attribute::{AttributeData, AttributeTable}; use crate::cursor::WriteCursor; use crate::prelude::Connection; @@ -26,7 +26,7 @@ pub(crate) mod sealed { use super::*; pub trait DynamicAttributeServer { - fn process(&self, connection: &Connection, packet: &AttReq, rx: &mut [u8]) -> Result, Error>; + fn process(&self, connection: &Connection, packet: &AttClient, rx: &mut [u8]) -> Result, Error>; fn should_notify(&self, connection: &Connection, cccd_handle: u16) -> bool; fn set(&self, characteristic: u16, input: &[u8]) -> Result<(), Error>; } @@ -37,7 +37,7 @@ pub trait DynamicAttributeServer: sealed::DynamicAttributeServer {} impl DynamicAttributeServer for AttributeServer<'_, M, MAX> {} impl sealed::DynamicAttributeServer for AttributeServer<'_, M, MAX> { - fn process(&self, connection: &Connection, packet: &AttReq, rx: &mut [u8]) -> Result, Error> { + fn process(&self, connection: &Connection, packet: &AttClient, rx: &mut [u8]) -> Result, Error> { let res = AttributeServer::process(self, connection, packet, rx)?; Ok(res) } @@ -441,51 +441,59 @@ impl<'values, M: RawMutex, const MAX: usize> AttributeServer<'values, M, MAX> { pub fn process( &self, connection: &Connection, - packet: &AttReq, + packet: &AttClient, rx: &mut [u8], ) -> Result, codec::Error> { let len = match packet { - AttReq::ReadByType { + AttClient::Request(AttReq::ReadByType { start, end, attribute_type, - } => self.handle_read_by_type_req(connection, rx, *start, *end, attribute_type)?, + }) => self.handle_read_by_type_req(connection, rx, *start, *end, attribute_type)?, - AttReq::ReadByGroupType { start, end, group_type } => { + AttClient::Request(AttReq::ReadByGroupType { start, end, group_type }) => { self.handle_read_by_group_type_req(connection, rx, *start, *end, group_type)? } - AttReq::FindInformation { + AttClient::Request(AttReq::FindInformation { start_handle, end_handle, - } => self.handle_find_information(rx, *start_handle, *end_handle)?, + }) => self.handle_find_information(rx, *start_handle, *end_handle)?, - AttReq::Read { handle } => self.handle_read_req(connection, rx, *handle)?, + AttClient::Request(AttReq::Read { handle }) => self.handle_read_req(connection, rx, *handle)?, - AttReq::WriteCmd { handle, data } => { + AttClient::Command(AttCmd::Write { handle, data }) => { self.handle_write_cmd(connection, rx, *handle, data)?; 0 } - AttReq::Write { handle, data } => self.handle_write_req(connection, rx, *handle, data)?, + AttClient::Request(AttReq::Write { handle, data }) => { + self.handle_write_req(connection, rx, *handle, data)? + } - AttReq::ExchangeMtu { mtu } => 0, // Done outside, + AttClient::Request(AttReq::ExchangeMtu { mtu }) => 0, // Done outside, - AttReq::FindByTypeValue { + AttClient::Request(AttReq::FindByTypeValue { start_handle, end_handle, att_type, att_value, - } => self.handle_find_type_value(rx, *start_handle, *end_handle, *att_type, att_value)?, + }) => self.handle_find_type_value(rx, *start_handle, *end_handle, *att_type, att_value)?, - AttReq::PrepareWrite { handle, offset, value } => { + AttClient::Request(AttReq::PrepareWrite { handle, offset, value }) => { self.handle_prepare_write(connection, rx, *handle, *offset, value)? } - AttReq::ExecuteWrite { flags } => self.handle_execute_write(rx, *flags)?, + AttClient::Request(AttReq::ExecuteWrite { flags }) => self.handle_execute_write(rx, *flags)?, + + AttClient::Request(AttReq::ReadBlob { handle, offset }) => { + self.handle_read_blob(connection, rx, *handle, *offset)? + } - AttReq::ReadBlob { handle, offset } => self.handle_read_blob(connection, rx, *handle, *offset)?, + AttClient::Request(AttReq::ReadMultiple { handles }) => { + self.handle_read_multiple(connection, rx, handles)? + } - AttReq::ReadMultiple { handles } => self.handle_read_multiple(connection, rx, handles)?, + AttClient::Confirmation(_) => 0, }; if len > 0 { Ok(Some(len)) diff --git a/host/src/gatt.rs b/host/src/gatt.rs index 43daf6d1..995281bb 100644 --- a/host/src/gatt.rs +++ b/host/src/gatt.rs @@ -14,7 +14,7 @@ use embassy_sync::channel::{Channel, DynamicReceiver}; use embassy_sync::pubsub::{self, PubSubChannel, WaitResult}; use heapless::Vec; -use crate::att::{self, AttReq, AttRsp, ATT_HANDLE_VALUE_NTF}; +use crate::att::{self, Att, AttClient, AttCmd, AttReq, AttRsp, AttServer, AttUns, ATT_HANDLE_VALUE_NTF}; use crate::attribute::{AttributeData, Characteristic, CharacteristicProp, Uuid, CCCD}; use crate::attribute_server::{AttributeServer, DynamicAttributeServer}; use crate::connection::Connection; @@ -35,19 +35,33 @@ impl<'stack> GattData<'stack> { Self { pdu, connection } } - /// Get the raw request. - pub fn request(&self) -> AttReq<'_> { - // We know it has been checked, therefore this cannot fail - unwrap!(AttReq::decode(self.pdu.as_ref())) + /// Get the raw incoming ATT PDU. + pub fn incoming(&self) -> AttClient<'_> { + // We know that: + // - The PDU is decodable, as it was already decoded once before adding it to the connection queue + // - The PDU is of type `Att::Client` because only those types of PDUs are added to the connection queue + let att = unwrap!(Att::decode(self.pdu.as_ref())); + let Att::Client(client) = att else { + unreachable!("Expected Att::Client, got {:?}", att) + }; + + client } /// Respond directly to request. pub async fn reply(self, rsp: AttRsp<'_>) -> Result<(), Error> { - let pdu = respond(&self.connection, rsp)?; + let pdu = send(&self.connection, AttServer::Response(rsp))?; self.connection.send(pdu).await; Ok(()) } + /// Send an unsolicited ATT PDU without having a request (e.g. notification or indication) + pub async fn send_unsolicited(connection: &Connection<'_>, uns: AttUns<'_>) -> Result<(), Error> { + let pdu = send(connection, AttServer::Unsolicited(uns))?; + connection.send(pdu).await; + Ok(()) + } + /// Handle the GATT data. /// /// May return an event that should be replied/processed. Uses the attribute server to @@ -56,30 +70,30 @@ impl<'stack> GattData<'stack> { self, server: &'m AttributeServer<'server, M, MAX>, ) -> Result>, Error> { - let att = self.request(); + let att = self.incoming(); match att { - AttReq::Write { handle, data: _ } => Ok(Some(GattEvent::Write(WriteEvent { + AttClient::Request(AttReq::Write { handle, data: _ }) => Ok(Some(GattEvent::Write(WriteEvent { value_handle: handle, pdu: Some(self.pdu), connection: self.connection, server, }))), - AttReq::WriteCmd { handle, data: _ } => Ok(Some(GattEvent::Write(WriteEvent { + AttClient::Command(AttCmd::Write { handle, data: _ }) => Ok(Some(GattEvent::Write(WriteEvent { value_handle: handle, pdu: Some(self.pdu), connection: self.connection, server, }))), - AttReq::Read { handle } => Ok(Some(GattEvent::Read(ReadEvent { + AttClient::Request(AttReq::Read { handle }) => Ok(Some(GattEvent::Read(ReadEvent { value_handle: handle, pdu: Some(self.pdu), connection: self.connection, server, }))), - AttReq::ReadBlob { handle, offset } => Ok(Some(GattEvent::Read(ReadEvent { + AttClient::Request(AttReq::ReadBlob { handle, offset }) => Ok(Some(GattEvent::Read(ReadEvent { value_handle: handle, pdu: Some(self.pdu), connection: self.connection, @@ -227,7 +241,12 @@ fn process_accept<'stack>( connection: &Connection<'stack>, server: &dyn DynamicAttributeServer, ) -> Result, Error> { - let att = unwrap!(AttReq::decode(pdu.as_ref())); + // - The PDU is decodable, as it was already decoded once before adding it to the connection queue + // - The PDU is of type `Att::Client` because only those types of PDUs are added to the connection queue + let att = unwrap!(Att::decode(pdu.as_ref())); + let Att::Client(att) = att else { + unreachable!("Expected Att::Client, got {:?}", att) + }; let mut tx = connection.alloc_tx()?; let mut w = WriteCursor::new(tx.as_mut()); let (mut header, mut data) = w.split(4)?; @@ -254,15 +273,15 @@ fn process_reject<'stack>( // We know it has been checked, therefore this cannot fail let request = pdu.as_ref()[0]; let rsp = AttRsp::Error { request, handle, code }; - let pdu = respond(connection, rsp)?; + let pdu = send(connection, AttServer::Response(rsp))?; Ok(Reply::new(connection.clone(), Some(pdu))) } -fn respond<'stack>(conn: &Connection<'stack>, rsp: AttRsp<'_>) -> Result { +fn send<'stack>(conn: &Connection<'stack>, att: AttServer<'_>) -> Result { let mut tx = conn.alloc_tx()?; let mut w = WriteCursor::new(tx.as_mut()); let (mut header, mut data) = w.split(4)?; - data.write(rsp)?; + data.write(Att::Server(att))?; let mtu = conn.get_att_mtu(); data.truncate(mtu as usize); @@ -383,15 +402,17 @@ impl<'reference, T: Controller, const MAX_SERVICES: usize, const L2CAP_MTU: usiz for GattClient<'reference, T, MAX_SERVICES, L2CAP_MTU> { async fn request(&self, req: AttReq<'_>) -> Result> { + let data = Att::Client(AttClient::Request(req)); + let header = L2capHeader { channel: crate::types::l2cap::L2CAP_CID_ATT, - length: req.size() as u16, + length: data.size() as u16, }; let mut buf = [0; L2CAP_MTU]; let mut w = WriteCursor::new(&mut buf); w.write_hci(&header)?; - w.write(req)?; + w.write(data)?; let mut grant = self .stack @@ -419,9 +440,9 @@ impl<'reference, C: Controller, const MAX_SERVICES: usize, const L2CAP_MTU: usiz let mut buf = [0; 7]; let mut w = WriteCursor::new(&mut buf); w.write_hci(&l2cap)?; - w.write(att::AttReq::ExchangeMtu { + w.write(att::Att::Client(att::AttClient::Request(att::AttReq::ExchangeMtu { mtu: L2CAP_MTU as u16 - 4, - })?; + })))?; let mut grant = stack.host.l2cap(connection.handle(), w.len() as u16, 1).await?; grant.send(w.finish()).await?; @@ -455,7 +476,7 @@ impl<'reference, C: Controller, const MAX_SERVICES: usize, const L2CAP_MTU: usiz }; let pdu = self.request(data).await?; - let res = AttRsp::decode(pdu.as_ref())?; + let res = Self::response(pdu.as_ref())?; match res { AttRsp::Error { request, handle, code } => { if code == att::AttErrorCode::ATTRIBUTE_NOT_FOUND { @@ -509,7 +530,7 @@ impl<'reference, C: Controller, const MAX_SERVICES: usize, const L2CAP_MTU: usiz }; let pdu = self.request(data).await?; - match AttRsp::decode(pdu.as_ref())? { + match Self::response(pdu.as_ref())? { AttRsp::ReadByType { mut it } => { while let Some(Ok((handle, item))) = it.next() { if item.len() < 5 { @@ -563,7 +584,7 @@ impl<'reference, C: Controller, const MAX_SERVICES: usize, const L2CAP_MTU: usiz let pdu = self.request(data).await?; - match AttRsp::decode(pdu.as_ref())? { + match Self::response(pdu.as_ref())? { AttRsp::ReadByType { mut it } => { if let Some(Ok((handle, item))) = it.next() { Ok(( @@ -593,7 +614,7 @@ impl<'reference, C: Controller, const MAX_SERVICES: usize, const L2CAP_MTU: usiz let pdu = self.request(data).await?; - match AttRsp::decode(pdu.as_ref())? { + match Self::response(pdu.as_ref())? { AttRsp::Read { data } => { let to_copy = data.len().min(dest.len()); dest[..to_copy].copy_from_slice(&data[..to_copy]); @@ -621,7 +642,7 @@ impl<'reference, C: Controller, const MAX_SERVICES: usize, const L2CAP_MTU: usiz let pdu = self.request(data).await?; - match AttRsp::decode(pdu.as_ref())? { + match Self::response(pdu.as_ref())? { AttRsp::ReadByType { mut it } => { let mut to_copy = 0; if let Some(item) = it.next() { @@ -648,7 +669,7 @@ impl<'reference, C: Controller, const MAX_SERVICES: usize, const L2CAP_MTU: usiz }; let pdu = self.request(data).await?; - match AttRsp::decode(pdu.as_ref())? { + match Self::response(pdu.as_ref())? { AttRsp::Write => Ok(()), AttRsp::Error { request, handle, code } => Err(Error::Att(code).into()), _ => Err(Error::InvalidValue.into()), @@ -673,7 +694,7 @@ impl<'reference, C: Controller, const MAX_SERVICES: usize, const L2CAP_MTU: usiz // set the CCCD let pdu = self.request(data).await?; - match AttRsp::decode(pdu.as_ref())? { + match Self::response(pdu.as_ref())? { AttRsp::Write => { let listener = self .notifications @@ -703,7 +724,7 @@ impl<'reference, C: Controller, const MAX_SERVICES: usize, const L2CAP_MTU: usiz // set the CCCD let pdu = self.request(data).await?; - match AttRsp::decode(pdu.as_ref())? { + match Self::response(pdu.as_ref())? { AttRsp::Write => Ok(()), AttRsp::Error { request, handle, code } => Err(Error::Att(code).into()), _ => Err(Error::InvalidValue.into()), @@ -744,4 +765,12 @@ impl<'reference, C: Controller, const MAX_SERVICES: usize, const L2CAP_MTU: usiz } } } + + fn response<'a>(data: &'a [u8]) -> Result, BleHostError> { + let att = Att::decode(data)?; + match att { + Att::Server(AttServer::Response(rsp)) => Ok(rsp), + _ => Err(Error::InvalidValue.into()), + } + } } diff --git a/host/src/host.rs b/host/src/host.rs index 971cc938..b6d0a29c 100644 --- a/host/src/host.rs +++ b/host/src/host.rs @@ -33,6 +33,7 @@ use embassy_sync::waitqueue::WakerRegistration; use embassy_sync::{blocking_mutex::raw::NoopRawMutex, channel::Channel}; use futures::pin_mut; +use crate::att::{AttClient, AttServer}; use crate::channel_manager::{ChannelManager, ChannelStorage, PacketChannel}; use crate::command::CommandState; #[cfg(feature = "gatt")] @@ -348,10 +349,10 @@ where // Handle ATT MTU exchange here since it doesn't strictly require // gatt to be enabled. let a = att::Att::decode(&packet.as_ref()[..header.length as usize]); - if let Ok(att::Att::Req(att::AttReq::ExchangeMtu { mtu })) = a { + if let Ok(att::Att::Client(AttClient::Request(att::AttReq::ExchangeMtu { mtu }))) = a { let mtu = self.connections.exchange_att_mtu(acl.handle(), mtu); - let rsp = att::AttRsp::ExchangeMtu { mtu }; + let rsp = att::Att::Server(AttServer::Response(att::AttRsp::ExchangeMtu { mtu })); let l2cap = L2capHeader { channel: L2CAP_CID_ATT, length: 3, @@ -364,19 +365,19 @@ where info!("[host] agreed att MTU of {}", mtu); let len = w.len(); self.connections.try_outbound(acl.handle(), Pdu::new(packet, len))?; - } else if let Ok(att::Att::Rsp(att::AttRsp::ExchangeMtu { mtu })) = a { + } else if let Ok(att::Att::Server(AttServer::Response(att::AttRsp::ExchangeMtu { mtu }))) = a { info!("[host] remote agreed att MTU of {}", mtu); self.connections.exchange_att_mtu(acl.handle(), mtu); } else { #[cfg(feature = "gatt")] match a { - Ok(att::Att::Req(_)) => { + Ok(att::Att::Client(_)) => { let event = ConnectionEventData::Gatt { data: Pdu::new(packet, header.length as usize), }; self.connections.post_handle_event(acl.handle(), event)?; } - Ok(att::Att::Rsp(_)) => { + Ok(att::Att::Server(_)) => { if let Err(e) = self .att_client .try_send((acl.handle(), Pdu::new(packet, header.length as usize))) diff --git a/host/src/lib.rs b/host/src/lib.rs index 67ec0cf4..43278730 100644 --- a/host/src/lib.rs +++ b/host/src/lib.rs @@ -29,7 +29,7 @@ mod fmt; #[cfg(not(any(feature = "central", feature = "peripheral")))] compile_error!("Must enable at least one of the `central` or `peripheral` features"); -mod att; +pub mod att; #[cfg(feature = "central")] pub mod central; mod channel_manager;