From 14ea91c285eeba7d66a042a24f4f30c41c97abdd Mon Sep 17 00:00:00 2001 From: Finn Behrens Date: Mon, 19 Apr 2021 12:59:41 +0200 Subject: [PATCH] Rust NetDevice and NetDeviceOperationsVtable struct also adds drivers/net/dummy_rs.rs Signed-off-by: Finn Behrens --- drivers/net/Kconfig | 16 + drivers/net/Makefile | 1 + drivers/net/dummy_rs.rs | 228 +++++++++ rust/helpers.c | 33 +- rust/kernel/bindings_helper.h | 8 + rust/kernel/error.rs | 26 +- rust/kernel/lib.rs | 3 +- rust/kernel/net/device.rs | 890 ++++++++++++++++++++++++++++++++++ rust/kernel/net/ethtool.rs | 262 ++++++++++ rust/kernel/net/mod.rs | 70 +++ rust/kernel/net/netlink.rs | 138 ++++++ rust/kernel/net/rtnl.rs | 141 ++++++ rust/kernel/net/skbuff.rs | 151 ++++++ rust/module.rs | 193 ++++++++ 14 files changed, 2157 insertions(+), 3 deletions(-) create mode 100644 drivers/net/dummy_rs.rs create mode 100644 rust/kernel/net/device.rs create mode 100644 rust/kernel/net/ethtool.rs create mode 100644 rust/kernel/net/mod.rs create mode 100644 rust/kernel/net/netlink.rs create mode 100644 rust/kernel/net/rtnl.rs create mode 100644 rust/kernel/net/skbuff.rs diff --git a/drivers/net/Kconfig b/drivers/net/Kconfig index bcd31f458d1acd..3e7fc2c55a7ab4 100644 --- a/drivers/net/Kconfig +++ b/drivers/net/Kconfig @@ -72,6 +72,22 @@ config DUMMY To compile this driver as a module, choose M here: the module will be called dummy. +config DUMMY_RS + tristate "Dummy net driver support" + depends on HAS_RUST + help + This is essentially a bit-bucket device (i.e. traffic you send to + this device is consigned into oblivion) with a configurable IP + address. It is most commonly used in order to make your currently + inactive SLIP address seem like a real address for local programs. + If you use SLIP or PPP, you might want to say Y here. It won't + enlarge your kernel. What a deal. Read about it in the Network + Administrator's Guide, available from + . + + To compile this driver as a module, choose M here: the module + will be called dummy_rs. + config WIREGUARD tristate "WireGuard secure network tunnel" depends on NET && INET diff --git a/drivers/net/Makefile b/drivers/net/Makefile index f4990ff32fa4cc..8c2f8e10e49617 100644 --- a/drivers/net/Makefile +++ b/drivers/net/Makefile @@ -10,6 +10,7 @@ obj-$(CONFIG_BONDING) += bonding/ obj-$(CONFIG_IPVLAN) += ipvlan/ obj-$(CONFIG_IPVTAP) += ipvlan/ obj-$(CONFIG_DUMMY) += dummy.o +obj-$(CONFIG_DUMMY_RS) += dummy_rs.o obj-$(CONFIG_WIREGUARD) += wireguard/ obj-$(CONFIG_EQUALIZER) += eql.o obj-$(CONFIG_IFB) += ifb.o diff --git a/drivers/net/dummy_rs.rs b/drivers/net/dummy_rs.rs new file mode 100644 index 00000000000000..552aebc7e18fbc --- /dev/null +++ b/drivers/net/dummy_rs.rs @@ -0,0 +1,228 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Rust dummy network driver +//! +//! This is a demonstration of what a small driver looks like in Rust, based on drivers/net/dummy.c. +//! This code is provided as a demonstration only, not as a proposal to mass-rewrite existing drivers in Rust + +// TODO: copyright (see ./dummy.c) + + +#![no_std] +#![feature(allocator_api, global_asm)] + +use core::ops::Deref; + +use kernel::{net::netlink::{NlAttrVec, NlExtAck}, prelude::*}; +use kernel::net::prelude::*; +use kernel::net::device; +use kernel::net::rtnl; +use kernel::Error; + +module! { + type: RustNetDummy, + name: b"dummy_rs", + author: b"Rust for Linux Contributors", + description: b"Rust dummy network driver", + license: b"GPL v2", + alias: b"rtnl-link-dummy_rs", + params: { + numdummies: usize { + default: 0, + permissions: 0o644, + description: b"Number of dummy_rs pseudo devices", + }, + }, +} + +fn setup(dev: &mut NetDevice) { + pr_info!("called rtnl_setup"); + dev.ether_setup(); + + // Fill in device structure with ethernet-generic values. + dev.add_flag(device::Iff::NOARP); + dev.remove_flag(device::Iff::MULTICAST); + + dev.add_private_flag(device::IffPriv::LIVE_ADDR_CHANGE); + dev.add_private_flag(device::IffPriv::NO_QUEUE); + + let mut feature = device::feature::NetIF::new(); + + //feature.add(device::feature::NETIF_F_SG); + //feature.add(device::feature::NETIF_F_FRAGLIST_BIT as u64); + feature += device::feature::NETIF_F_SG; + feature += device::feature::NETIF_F_FRAGLIST; + feature += device::feature::NETIF_F_GSO_SOFTWARE; + feature += device::feature::NETIF_F_HW_CSUM; + feature += device::feature::NETIF_F_HIGHDMA; + feature += device::feature::NETIF_F_LLTX; + feature += device::feature::NETIF_F_GSO_ENCAP_ALL; + + dev.set_features(feature); + dev.set_hw_features(feature); + dev.set_hw_enc_features(feature); + + dev.hw_addr_random(); + dev.set_mtu(0, 0); +} + +fn validate(tb: &NlAttrVec, data: &NlAttrVec, ext_ack: &NlExtAck) -> KernelResult<()> { + pr_info!("validate nlattr"); + if let Some(addr) = tb.get(kernel::bindings::IFLA_ADDRESS) { + if addr.nla_len() != kernel::net::netlink::ETH_ALEN { + return Err(Error::EINVAL); + } + if !addr.is_valid_ether_addr() { + return Err(Error::EADDRNOTAVAIL); + } + } else { + pr_info!("no IFLA_ADDRESS in nlattr"); + } + pr_info!("valid nlattr"); + Ok(()) +} + +rtnl_link_ops! { + kind: b"dummy_rs", + type: DummyRsDev, + setup: setup, + validate: validate, +} + +struct RustNetDummy { + dev: NetDevice, +} + + +impl KernelModule for RustNetDummy { + fn init() -> KernelResult { + { + let lock =THIS_MODULE.kernel_param_lock(); + pr_info!("Rust Network Dummy with {} pseudo devices\n", numdummies.read(&lock)); + } + + unsafe { dummy_rs_rtnl_link_ops.register() }?; + + let mut dev = NetDevice::new(DummyRsDev, kernel::cstr!("dummyrs%d"), kernel::net::device::NetNameAssingType::Enum, 1, 1)?; + dev.set_rtnl_ops( unsafe { &dummy_rs_rtnl_link_ops }); + + if let Err(e) = dev.register() { + pr_warn!("could not register: {}", e.to_kernel_errno()); + return Err(e); + } + pr_info!("device registered"); + + Ok(RustNetDummy { + dev, + }) + } +} + +impl Drop for RustNetDummy { + fn drop(&mut self) { + pr_info!("remove rust net dummy"); + + + unsafe { + //let ptr = &dummy_rs_rtnl_link_ops as *const _ as *mut kernel::bindings::rtnl_link_ops; + pr_info!("rtnl: {:#?}", &dummy_rs_rtnl_link_ops.0); + let ptr = &dummy_rs_rtnl_link_ops.get_ptr(); + pr_info!("rtnl_link_ops_ptr: {:?}", ptr); + } + // TODO rtnl_link_unregister + // TODO: remove unsafe somehow + unsafe { dummy_rs_rtnl_link_ops.unregister() }; + } +} + + +struct DummyRsDev; + +impl NetDeviceOps for DummyRsDev { + kernel::declare_net_device_ops!( + get_stats64, + change_carrier, + validate_addr, + set_mac_addr, + set_rx_mode + ); + + fn init(dev: &mut NetDevice) -> KernelResult<()> { + dev.set_new_pcpu_lstats()?; + Ok(()) + } + + fn uninit(dev: &mut NetDevice) { + unsafe { dev.free_lstats() }; + } + + fn start_xmit(skb: SkBuff, dev: &mut NetDevice) -> kernel::net::device::NetdevTX { + let mut skb = skb; + + // TODO: dev_lstatt_add(dev, skb->len) + dev.lstats_add(skb.len()); + + skb.tx_timestamp(); + drop(skb); + + pr_info!("start_xmit called"); + + kernel::net::device::NetdevTX::TX_OK + } + + fn get_stats64(dev: &NetDevice, stats: &mut rtnl::RtnlLinkStats64) { + pr_info!("get stats64"); + stats.dev_read(dev); + } + + fn change_carrier(dev: &mut NetDevice, new_carrier: bool) -> KernelResult<()> { + dev.carrier_set(new_carrier); + + Ok(()) + } + + fn validate_addr(dev: &NetDevice) -> KernelResult<()> { + pr_info!("eth_validate_addr"); + device::helpers::eth_validate_addr(dev) + } + + fn set_mac_addr(dev: &mut NetDevice, p: *mut kernel::c_types::c_void) -> KernelResult<()> { + device::helpers::eth_mac_addr(dev, p) + } + + // Someting about faking multicast + fn set_rx_mode(dev: &mut NetDevice) { + pr_info!("set_rx_mode"); + } +} + +impl NetDeviceAdapter for DummyRsDev { + type Inner = Self; + + type Ops = Self; + + type EthOps = Self; + + fn setup(dev: &mut NetDevice) { + pr_info!("called netdev_setup"); + setup(dev); + //dev.set_rtnl_ops( unsafe { &dummy_rs_rtnl_link_ops }); + } +} + +impl EthToolOps for DummyRsDev { + kernel::declare_eth_tool_ops!(get_drvinfo, get_ts_info); + + fn get_drvinfo(_dev: &NetDevice, info: &mut ethtool::EthtoolDrvinfo) { + // TODO: how to do this more efficient without unsafe? + // FIXME: !! + let info: &kernel::bindings::ethtool_drvinfo = info.deref(); + unsafe { + kernel::bindings::strlcpy(&(info.driver) as *const _ as *mut i8, b"dummy_rs\0" as *const _ as *mut i8, 32); + } + } + + fn get_ts_info(dev: &NetDevice, info: &mut ethtool::EthToolTsInfo) -> KernelResult<()> { + kernel::net::ethtool::helpers::ethtool_op_get_ts_info(dev, info) + } +} diff --git a/rust/helpers.c b/rust/helpers.c index f38ed02438ae9b..a3ea572c416dd1 100644 --- a/rust/helpers.c +++ b/rust/helpers.c @@ -7,6 +7,9 @@ #include #include #include +#include +#include +#include void rust_helper_BUG(void) { @@ -105,8 +108,36 @@ size_t rust_helper_copy_to_iter(const void *addr, size_t bytes, struct iov_iter } EXPORT_SYMBOL_GPL(rust_helper_copy_to_iter); -#if !defined(CONFIG_ARM) +void *rust_helper_netdev_priv(struct net_device *dev) +{ + return netdev_priv(dev); +} +EXPORT_SYMBOL_GPL(rust_helper_netdev_priv); + +void rust_helper_eth_hw_addr_random(struct net_device *dev) +{ + eth_hw_addr_random(dev); +} +EXPORT_SYMBOL_GPL(rust_helper_eth_hw_addr_random); + +int rust_helper_net_device_set_new_lstats(struct net_device *dev) +{ + dev->lstats = netdev_alloc_pcpu_stats(struct pcpu_lstats); + if (!dev->lstats) + return -ENOMEM; + + return 0; +} +EXPORT_SYMBOL_GPL(rust_helper_net_device_set_new_lstats); + +void rust_helper_dev_lstats_add(struct net_device *dev, unsigned int len) +{ + dev_lstats_add(dev, len); +} +EXPORT_SYMBOL_GPL(rust_helper_dev_lstats_add); + // See https://github.com/rust-lang/rust-bindgen/issues/1671 +#if !defined(CONFIG_ARM) static_assert(__builtin_types_compatible_p(size_t, uintptr_t), "size_t must match uintptr_t, what architecture is this??"); #endif diff --git a/rust/kernel/bindings_helper.h b/rust/kernel/bindings_helper.h index ec052e54350fd9..af3a190471e1ed 100644 --- a/rust/kernel/bindings_helper.h +++ b/rust/kernel/bindings_helper.h @@ -2,6 +2,12 @@ #include #include +#include +#include +#include +#include +#include +#include #include #include #include @@ -17,3 +23,5 @@ // `bindgen` gets confused at certain things const gfp_t BINDINGS_GFP_KERNEL = GFP_KERNEL; const gfp_t BINDINGS___GFP_ZERO = __GFP_ZERO; + +const int BINDINGS_NLA_HDRLEN = NLA_HDRLEN; diff --git a/rust/kernel/error.rs b/rust/kernel/error.rs index 432d866232c13c..fdf1f8268794b2 100644 --- a/rust/kernel/error.rs +++ b/rust/kernel/error.rs @@ -6,7 +6,7 @@ use crate::{bindings, c_types}; use alloc::{alloc::AllocError, collections::TryReserveError}; -use core::{num::TryFromIntError, str::Utf8Error}; +use core::{convert::TryFrom, num::TryFromIntError, str::Utf8Error}; /// Generic integer kernel error. /// @@ -48,6 +48,9 @@ impl Error { /// Interrupted system call. pub const EINTR: Self = Error(-(bindings::EINTR as i32)); + /// Cannot assign requested address + pub const EADDRNOTAVAIL: Self = Error(-(bindings::EADDRNOTAVAIL as i32)); + /// Creates an [`Error`] from a kernel error code. pub fn from_kernel_errno(errno: c_types::c_int) -> Error { Error(errno) @@ -104,3 +107,24 @@ impl From for Error { Error::ENOMEM } } + +/// Used by the rtnl_link_ops macro to interface with C +pub fn c_from_kernel_result(r: KernelResult) -> T +where + T: TryFrom, + T::Error: core::fmt::Debug, +{ + match r { + Ok(v) => v, + Err(e) => T::try_from(e.to_kernel_errno()).unwrap(), + } +} + +#[macro_export] +macro_rules! c_from_kernel_result { + ($($tt:tt)*) => {{ + $crate::c_from_kernel_result((|| { + $($tt)* + })()) + }}; +} \ No newline at end of file diff --git a/rust/kernel/lib.rs b/rust/kernel/lib.rs index 3799a03fcf875a..2254fe484e1cd2 100644 --- a/rust/kernel/lib.rs +++ b/rust/kernel/lib.rs @@ -41,6 +41,7 @@ pub mod c_types; pub mod chrdev; mod error; pub mod file_operations; +pub mod net; pub mod miscdev; pub mod pages; @@ -64,7 +65,7 @@ pub mod iov_iter; mod types; pub mod user_ptr; -pub use crate::error::{Error, KernelResult}; +pub use crate::error::{Error, KernelResult, c_from_kernel_result}; pub use crate::types::{CStr, Mode}; /// Page size defined in terms of the `PAGE_SHIFT` macro from C. diff --git a/rust/kernel/net/device.rs b/rust/kernel/net/device.rs new file mode 100644 index 00000000000000..e06007d833bb29 --- /dev/null +++ b/rust/kernel/net/device.rs @@ -0,0 +1,890 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Net Device Operations. +//! +//! C header: [`include/linux/netdevice.h`](../../../../include/linux/netdevice.h) + +use core::{marker, mem, ops::Deref, ops::DerefMut, ptr}; + +use crate::bindings; +use crate::{c_types, CStr}; +use crate::error::{Error, KernelResult}; +use crate::c_from_kernel_result; +use crate::sync::{CondVar, Ref, RefCounted}; +use crate::user_ptr::{UserSlicePtr, UserSlicePtrReader, UserSlicePtrWriter}; + +use super::ethtool::EthToolOps; +use super::rtnl::{RtnlLinkOps, RtnlLock, RtnlLinkStats64}; +use super::skbuff::SkBuff; + +/*pub(crate) fn from_kernel_result(r: KernelResult) -> T +where + T: TryFrom, + T::Error: core::fmt::Debug, +{ + match r { + Ok(v) => v, + Err(e) => T::try_from(e.to_kernel_errno()).unwrap(), + } +} +*/ + + +extern "C" { + #[allow(improper_ctypes)] + fn rust_helper_netdev_priv(dev: *const bindings::net_device) -> *mut c_types::c_void; + + #[allow(improper_ctypes)] + fn rust_helper_eth_hw_addr_random(dev: *const bindings::net_device); + + #[allow(improper_ctypes)] + fn rust_helper_net_device_set_new_lstats(dev: *mut bindings::net_device) -> c_types::c_int; + + #[allow(improper_ctypes)] + fn rust_helper_dev_lstats_add(dev: *mut bindings::net_device, len: u32); +} + +/// interface name assignment types (sysfs name_assign_type attribute) +#[repr(u8)] +pub enum NetNameAssingType { + Unknown = bindings::NET_NAME_UNKNOWN as u8, + Enum = bindings::NET_NAME_ENUM as u8, +} + +unsafe extern "C" fn setup_netdev_callback(dev: *mut bindings::net_device) { + // TODO: pass on dev + let mut dev = NetDevice::::from_ptr(dev); + dev.needs_free_netdev = true; // TODO: is this reasonable? + dev.netdev_ops = NetDeviceOperationsVtable::::build(); + dev.ethtool_ops = super::ethtool::EthToolOperationsVtable::::build(); + + T::setup(&mut dev); +} + +/// Wraps the kernel's `struct net_device`. +/// +/// # Invariants +/// +/// The pointer [`NetDevice::ptr`] is non-null and valid. +pub struct NetDevice { + ptr: *const bindings::net_device, + priv_data: marker::PhantomData, +} + +unsafe impl Sync for NetDevice {} +//unsafe impl Send for NetDevice {} + +impl NetDevice { + pub fn new(priv_data: T, format_name: CStr<'static>, name_assign_type: NetNameAssingType, txqs: u32, rxqs: u32) -> KernelResult { + let _lock = RtnlLock::lock(); + // Lock is hold + let dev = unsafe { Self::new_locked(priv_data, format_name, name_assign_type, txqs, rxqs) }; + dev + } + + pub unsafe fn new_locked(priv_data: T, format_name: CStr<'static>, name_assign_type: NetNameAssingType, txqs: u32, rxqs: u32) -> KernelResult { + // TODO: check for {t,r}xqs bigger 0 + let size = mem::size_of::() as i32; + + // Safety: TODO + let ptr = bindings::alloc_netdev_mqs(size, format_name.as_ptr() as _, name_assign_type as u8, Some(setup_netdev_callback::), txqs, rxqs); + if ptr.is_null() { + return Err(Error::ENOMEM); + } + + if size != 0 { + // Safety: T is valid and dest is created by alloc_netdev_mqs + let dest = rust_helper_netdev_priv(ptr) as *mut T; + ptr::write(dest, priv_data); + } + + Ok(Self { + ptr, + priv_data: marker::PhantomData::, + }) + } + + // TODO: pin? + pub fn get_priv_data(&self) -> &T { + // SAFETY: + let priv_ptr = unsafe { rust_helper_netdev_priv(self.ptr) } as *mut T; + + // SAFETY: ptr is valid and of type T + unsafe { priv_ptr.as_ref() }.unwrap() + } + + // TODO: pin? + pub unsafe fn get_priv_data_mut(&mut self) -> &mut T { + // SAFETY: + let priv_ptr = rust_helper_netdev_priv(self.ptr) as *mut T; + + // SAFETY: ptr is valid and of type T + priv_ptr.as_mut().unwrap() + } + + /// ether_setup - setup Ethernet network device + /// + /// Fill in the fields of the device structure with Ethernet-generic values. + pub fn ether_setup(&mut self) { + // SAFETY: self.ptr is valid + unsafe { bindings::ether_setup(self.ptr as *mut bindings::net_device) } + } + + /// hw_addr_random - Generate software assigned random Ethernet and set device flag + /// + /// Generate a random Ethernet address (MAC) to be used by a net device + /// and set addr_assign_type so the state can be read by sysfs and be + /// used by userspace. + pub fn hw_addr_random(&mut self) { + // SAFETY: self.ptr is valid + unsafe { rust_helper_eth_hw_addr_random(self.ptr) }; + } + + /// register - register a network device + /// + /// Take a completed network device structure and add it to the kernel + /// interfaces. A %NETDEV_REGISTER message is sent to the netdev notifier + /// chain. 0 is returned on success. A negative errno code is returned + /// on a failure to set up the device, or if the name is a duplicate. + /// + /// This is a wrapper around register_netdevice that takes the rtnl semaphore + /// and expands the device name if you passed a format string to + /// alloc_netdev. + pub fn register(&self) -> KernelResult<()> { + // SAFETY: self.ptr is valid + let err = unsafe { bindings::register_netdev(self.ptr as *mut bindings::net_device) }; + + if err != 0 { + Err(Error::from_kernel_errno(err)) + } else { + Ok(()) + } + } + + /// register_locked - register a network device if the RtnlLock is already hold + /// + /// Take a completed network device structure and add it to the kernel + /// interfaces. A %NETDEV_REGISTER message is sent to the netdev notifier + /// chain. 0 is returned on success. A negative errno code is returned + /// on a failure to set up the device, or if the name is a duplicate. + /// + /// Callers must hold the rtnl semaphore. You may want + /// [`register`] instead of this. + /// + /// BUGS: + /// The locking appears insufficient to guarantee two parallel registers + /// will not get the same name. + /// + /// # Safety + /// + /// caller must hold the RtnlLock and semaphore + pub unsafe fn register_locked(&self) -> KernelResult<()> { + let err = bindings::register_netdevice(self.ptr as *mut bindings::net_device); + + if err != 0 { + Err(Error::from_kernel_errno(err)) + } else { + Ok(()) + } + } + + /// set_rtnl_ops - set the rtnl_link_ops to a network interface + /// + /// Takes a static mut created with [`kernel::net::rtnl_link_ops!`] and assing it to self + pub fn set_rtnl_ops(&self, ops: &'static super::rtnl::RtnlLinkOps) { + // get rtnl_lock + let _lock = RtnlLock::lock(); + + // SAFETY: lock is hold + unsafe { self.set_rtnl_ops_locked(ops) } + } + + /// set_rtnl_ops_locked - set the rtnl_link_ops to a network interface, while the caller holds the rtnl_lock + /// + /// Takes a static mut created with [`kernel::net::rtnl_link_ops!`] and assing it to self + pub unsafe fn set_rtnl_ops_locked(&self, ops: &'static super::rtnl::RtnlLinkOps) { + // SAFETY: prt is valid if self is valid + let mut dev = (self.ptr as *mut bindings::net_device ).as_mut().unwrap(); + + //dev.rtnl_link_ops = (&ops.0) as *const _ as *mut _; + dev.rtnl_link_ops = ops.get_ptr() as *mut bindings::rtnl_link_ops; + } + + pub fn add_flag(&mut self, flag: Iff) { + // SAFETY: prt is valid if self is valid + let mut dev = unsafe { (self.ptr as *mut bindings::net_device).as_mut().unwrap() }; + + dev.flags |= flag as u32; + } + + pub fn remove_flag(&mut self, flag: Iff) { + // SAFETY: prt is valid if self is valid + let mut dev = unsafe { (self.ptr as *mut bindings::net_device).as_mut().unwrap() }; + + dev.flags &= !(flag as u32); + } + + pub fn add_private_flag(&mut self, flag: IffPriv) { + // SAFETY: prt is valid if self is valid + let mut dev = unsafe { (self.ptr as *mut bindings::net_device).as_mut().unwrap() }; + + dev.priv_flags |= flag as u32; + } + + pub fn remove_private_flag(&mut self, flag: IffPriv) { + // SAFETY: prt is valid if self is valid + let mut dev = unsafe { (self.ptr as *mut bindings::net_device).as_mut().unwrap() }; + + dev.priv_flags &= !(flag as u32); + } + + pub fn set_features(&mut self, features: feature::NetIF) { + // SAFETY: ptr is valid if self is valid + let mut dev = unsafe { (self.ptr as *mut bindings::net_device).as_mut() }.unwrap(); + + dev.features = features.into(); + } + + + pub fn get_features(&self) -> feature::NetIF { + // SAFETY: ptr is valid if self is valid + let dev = unsafe { self.ptr.as_ref() }.unwrap(); + + feature::NetIF::from(dev.features) + } + + pub fn set_hw_features(&mut self, features: feature::NetIF) { + // SAFETY: ptr is valid if self is valid + let mut dev = unsafe { (self.ptr as *mut bindings::net_device).as_mut() }.unwrap(); + + dev.hw_features = features.into(); + } + + pub fn get_hw_features(&self) -> feature::NetIF { + // SAFETY: ptr is valid if self is valid + let dev = unsafe { self.ptr.as_ref() }.unwrap(); + + feature::NetIF::from(dev.hw_features) + } + + pub fn set_hw_enc_features(&mut self, features: feature::NetIF) { + // SAFETY: ptr is valid if self is valid + let mut dev = unsafe { (self.ptr as *mut bindings::net_device).as_mut() }.unwrap(); + + dev.hw_enc_features = features.into(); + } + + pub fn get_hw_enc_features(&self) -> feature::NetIF { + // SAFETY: ptr is valid if self is valid + let dev = unsafe { self.ptr.as_ref() }.unwrap(); + + feature::NetIF::from(dev.hw_enc_features) + } + + pub fn set_mtu(&mut self, min: u32, max: u32) { + // SAFETY: prt is valid if self is valid + let mut dev = unsafe { (self.ptr as *mut bindings::net_device).as_mut().unwrap() }; + + dev.min_mtu = min; + dev.max_mtu = max; + } + + pub fn set_new_pcpu_lstats(&mut self) -> KernelResult<()> { + // SAFETY: calling c function + let ret = unsafe { + rust_helper_net_device_set_new_lstats(self.ptr as *mut bindings::net_device) + }; + + if ret != 0 { + Err(Error::from_kernel_errno(ret)) + } else { + Ok(()) + } + } + + /// # Safety + /// + /// Only call when the same device had set_new_pcpu_lstats called + pub unsafe fn free_lstats(&mut self) { + // SAFETY: self.ptr->lstats is valid if self is valid + + let net_device: &bindings::net_device = self.deref(); + + // __bindgen_anon_1: net_device__bindgen_ty_4 + + let lstats = net_device.__bindgen_anon_1.lstats; + if !lstats.is_null() { + unsafe { + bindings::free_percpu(lstats as *mut _) + } + } + } + + pub fn lstats_add(&mut self, len: u32) { + // SAFETD: calling c function + unsafe { + rust_helper_dev_lstats_add(self.ptr as *mut bindings::net_device, len); + } + } + + pub fn carrier_set(&mut self, status: bool) { + // SAFETY: self.ptr is valid if self is valid + if status { + unsafe { bindings::netif_carrier_on(self.ptr as *mut bindings::net_device) } + } else { + unsafe { bindings::netif_carrier_off(self.ptr as *mut bindings::net_device) } + } + } + + /// Constructs a new [`struct net_device`] wrapper. + /// + /// # Safety + /// + /// The pointer `ptr` must be non-null and valid for the lifetime of the object. + /// The private data must be non-null and valid for the lifetime of the object and be of type T. + pub unsafe fn from_ptr(ptr: *const bindings::net_device) -> Self { + // INVARIANTS: the safety contract ensures the type invariant will hold. + + // TODO: parse private data + Self { + ptr, + priv_data: marker::PhantomData::, + } + } + + pub unsafe fn get_ptr(&self) -> *const bindings::net_device { + self.ptr + } +} + +impl Deref for NetDevice { + type Target = bindings::net_device; + + fn deref(&self) -> &Self::Target { + // SAFETY: ptr is valid + unsafe { self.ptr.as_ref() }.unwrap() + } +} + +impl DerefMut for NetDevice { + fn deref_mut(&mut self) -> &mut Self::Target { + // SAFETY: ptr is valid + unsafe { (self.ptr as *mut bindings::net_device).as_mut() }.unwrap() + } +} + +pub trait NetDeviceAdapter: Sized { + type Inner: Sized; // = Self + + type Ops: NetDeviceOps; + + type EthOps: EthToolOps; + + /// Callback to initialize the device + /// function table is setup by the abstraction based on Self::Ops + fn setup(dev: &mut NetDevice); +} + +#[repr(i32)] +/// maps to [`netdev_tx`] from the kernel. +pub enum NetdevTX { + TX_OK = bindings::netdev_tx_NETDEV_TX_OK, + TX_BUSY = bindings::netdev_tx_NETDEV_TX_BUSY, +} + + +unsafe extern "C" fn ndo_init_callback( + dev: *mut bindings::net_device +) -> c_types::c_int { + c_from_kernel_result! { + T::Ops::init(&mut NetDevice::::from_ptr(dev))?; + Ok(0) + } +} + +unsafe extern "C" fn ndo_uninit_callback( + dev: *mut bindings::net_device +) { + // SAFETY: pointer is valid as it comes form C + //let mut device = dev.as_mut().unwrap(); + T::Ops::uninit(&mut NetDevice::::from_ptr(dev)); +} + +unsafe extern "C" fn ndo_start_xmit_callback( + skb: *mut bindings::sk_buff, + dev: *mut bindings::net_device, +) -> bindings::netdev_tx_t { + //let device = NetDevice::::from_ptr(dev); + + T::Ops::start_xmit(SkBuff::from_ptr(skb), &mut NetDevice::from_ptr(dev)) as bindings::netdev_tx_t +} + +unsafe extern "C" fn ndo_get_stats64_callback(dev: *mut bindings::net_device, stats: *mut bindings::rtnl_link_stats64) { + T::Ops::get_stats64(&NetDevice::::from_ptr(dev), &mut RtnlLinkStats64::from_ptr(stats)); +} + +unsafe extern "C" fn ndo_change_carrier_callback(dev: *mut bindings::net_device, change_carrier: bool) -> c_types::c_int { + c_from_kernel_result! { + T::Ops::change_carrier(&mut NetDevice::::from_ptr(dev), change_carrier)?; + Ok(0) + } +} + +unsafe extern "C" fn ndo_validate_addr_callback(dev: *mut bindings::net_device) -> c_types::c_int { + c_from_kernel_result! { + T::Ops::validate_addr(&NetDevice::::from_ptr(dev))?; + Ok(0) + } +} + +unsafe extern "C" fn ndo_set_mac_address_callback(dev: *mut bindings::net_device, p: *mut c_types::c_void) -> c_types::c_int { + c_from_kernel_result! { + T::Ops::set_mac_addr(&mut NetDevice::::from_ptr(dev), p)?; + Ok(0) + } +} + +unsafe extern "C" fn ndo_set_rx_mode_callback(dev: *mut bindings::net_device) { + T::Ops::set_rx_mode(&mut NetDevice::::from_ptr(dev)) +} + +pub(crate) struct NetDeviceOperationsVtable(marker::PhantomData); + +impl NetDeviceOperationsVtable { + const VTABLE: bindings::net_device_ops = bindings::net_device_ops { + ndo_init: Some(ndo_init_callback::), + ndo_uninit: Some(ndo_uninit_callback::), + ndo_open: None, + ndo_stop: None, + ndo_start_xmit: Some(ndo_start_xmit_callback::), + ndo_features_check: None, + ndo_select_queue: None, + ndo_change_rx_flags: None, + ndo_set_rx_mode: if T::Ops::TO_USE.set_rx_mode { + Some(ndo_set_rx_mode_callback::) + } else { None }, + ndo_set_mac_address: if T::Ops::TO_USE.set_mac_addr { + Some(ndo_set_mac_address_callback::) + } else { None }, + ndo_validate_addr: if T::Ops::TO_USE.validate_addr { + Some(ndo_validate_addr_callback::) + } else { None }, + ndo_do_ioctl: None, + ndo_set_config: None, + ndo_change_mtu: None, + ndo_neigh_setup: None, + ndo_tx_timeout: None, + ndo_get_stats64: if T::Ops::TO_USE.get_stats64 { + Some(ndo_get_stats64_callback::) + } else { None }, + ndo_has_offload_stats: None, + ndo_get_offload_stats: None, + ndo_get_stats: None, + ndo_vlan_rx_add_vid: None, + ndo_vlan_rx_kill_vid: None, + + #[cfg(CONFIG_NET_POLL_CONTROLLER)] + ndo_poll_controller: None, + #[cfg(CONFIG_NET_POLL_CONTROLLER)] + ndo_netpoll_setup: None, + #[cfg(CONFIG_NET_POLL_CONTROLLER)] + ndo_netpoll_cleanup: None, + + ndo_set_vf_mac: None, + ndo_set_vf_vlan: None, + ndo_set_vf_rate: None, + ndo_set_vf_spoofchk: None, + ndo_set_vf_trust: None, + ndo_get_vf_config: None, + ndo_set_vf_link_state: None, + ndo_get_vf_stats: None, + ndo_set_vf_port: None, + ndo_get_vf_port: None, + ndo_get_vf_guid: None, + ndo_set_vf_guid: None, + ndo_set_vf_rss_query_en: None, + ndo_setup_tc: None, + + #[cfg(any(CONFIG_FCOE = "y", CONFIG_FCOE = "m"))] + ndo_fcoe_enable: None, + #[cfg(any(CONFIG_FCOE = "y", CONFIG_FCOE = "m"))] + ndo_fcoe_disable: None, + #[cfg(any(CONFIG_FCOE = "y", CONFIG_FCOE = "m"))] + ndo_fcoe_ddp_setup: None, + #[cfg(any(CONFIG_FCOE = "y", CONFIG_FCOE = "m"))] + ndo_fcoe_ddp_done: None, + #[cfg(any(CONFIG_FCOE = "y", CONFIG_FCOE = "m"))] + ndo_fcoe_ddp_target: None, + #[cfg(any(CONFIG_FCOE = "y", CONFIG_FCOE = "m"))] + ndo_fcoe_get_hbainfo: None, + + #[cfg(any(CONFIG_LIBFCOE = "y", CONFIG_LIBFCOE = "m"))] + ndo_fcoe_get_wwn: None, + + #[cfg(CONFIG_RFS_ACCEL)] + ndo_rx_flow_steer: None, + + ndo_add_slave: None, + ndo_del_slave: None, + ndo_get_xmit_slave: None, + ndo_sk_get_lower_dev: None, + ndo_fix_features: None, + ndo_set_features: None, + ndo_neigh_construct: None, + ndo_neigh_destroy: None, + ndo_fdb_add: None, + ndo_fdb_del: None, + ndo_fdb_dump: None, + ndo_fdb_get: None, + ndo_bridge_setlink: None, + ndo_bridge_getlink: None, + ndo_bridge_dellink: None, + ndo_change_carrier: if T::Ops::TO_USE.change_carrier { + Some(ndo_change_carrier_callback::) + } else { None }, + ndo_get_phys_port_id: None, + ndo_get_port_parent_id: None, + ndo_get_phys_port_name: None, + ndo_dfwd_add_station: None, + ndo_dfwd_del_station: None, + ndo_set_tx_maxrate: None, + ndo_get_iflink: None, + ndo_change_proto_down: None, + ndo_fill_metadata_dst: None, + ndo_set_rx_headroom: None, + ndo_bpf: None, + ndo_xdp_xmit: None, + ndo_xsk_wakeup: None, + ndo_get_devlink_port: None, + ndo_tunnel_ctl: None, + ndo_get_peer_dev: None, + }; + + + /// Builds an instance of [`struct net_device_ops`]. + /// + /// # Safety + /// + /// The caller must ensure that the adapter is compatible with the way the device is registered. + pub(crate) const unsafe fn build() -> &'static bindings::net_device_ops { + &Self::VTABLE + } +} + +/// Represents which fields of [`struct net_device_ops`] should pe populated with pointers. +pub struct ToUse { + /// The `ndo_change_carrier` field of [`struct net_device_ops`]. + pub change_carrier: bool, + + pub get_stats64: bool, + + pub validate_addr: bool, + + pub set_mac_addr: bool, + + pub set_rx_mode: bool, +} + +pub const USE_NONE: ToUse = ToUse { + change_carrier: false, + get_stats64: false, + validate_addr: false, + set_mac_addr: false, + set_rx_mode: false, +}; + +/// Defines the [`NetDeviceOps::TO_USE`] field based on a list of fields to be populated. +#[macro_export] +macro_rules! declare_net_device_ops { + () => { + const TO_USE: $crate::net::device::ToUse = $crate::net::device::USE_NONE; + }; + ($($i:ident),+) => { + const TO_USE: kernel::net::device::ToUse = + $crate::net::device::ToUse { + $($i: true),+ , + ..$crate::net::device::USE_NONE + }; + }; +} + + +/// Corresponds to the kernel's `struct net_device_ops`. +/// +/// You Implement this trait whenever you would create a `struct net_device_ops`.. +pub trait NetDeviceOps: Send + Sync + Sized { + /// The methods to use to populate [`struct net_device_ops`]. + const TO_USE: ToUse; + + /// This function is called once when a network device is registered. + /// The network device can use this for any late stage initialization + /// or semantic validation. It can fail with an error code which will + /// be propagated back to register_netdev. + fn init(dev: &mut NetDevice) -> KernelResult<()>; + + /// This function is called when device is unregistered or when registration + /// fails. It is not called if init fails. + fn uninit(dev: &mut NetDevice); + + + /// Called when a packet needs to be transmitted. + /// `Ok(())` returns NETDEV_TX_OK, Error maps to `NETDEV_TX_BUSY` + /// Returns NETDEV_TX_OK. Can return NETDEV_TX_BUSY, but you should stop + /// the queue before that can happen; it's for obsolete devices and weird + /// corner cases, but the stack really does a non-trivial amount + /// of useless work if you return NETDEV_TX_BUSY. + fn start_xmit(_skb: SkBuff, _dev: &mut NetDevice) -> NetdevTX { + NetdevTX::TX_OK + } + + fn get_stats64(dev: &NetDevice, stats: &mut RtnlLinkStats64) { } + + fn change_carrier(dev: &mut NetDevice, new_carrier: bool) -> KernelResult<()> { + Err(Error::EINVAL) + } + + fn validate_addr(dev: &NetDevice) -> KernelResult<()> { + Err(Error::EINVAL) + } + + fn set_mac_addr(dev: &mut NetDevice, p: *mut c_types::c_void) -> KernelResult<()> { + Err(Error::EINVAL) + } + + fn set_rx_mode(dev: &mut NetDevice) { + } +} + +/// Iff flags +#[repr(u32)] +#[allow(non_camel_case_types)] +pub enum Iff { + UP = bindings::net_device_flags_IFF_UP, + BROADCAST = bindings::net_device_flags_IFF_BROADCAST, + DEBUG = bindings::net_device_flags_IFF_DEBUG, + LOOPBACK = bindings::net_device_flags_IFF_LOOPBACK, + POINTOPOINT = bindings::net_device_flags_IFF_POINTOPOINT, + NOTRAILERS = bindings::net_device_flags_IFF_NOTRAILERS, + RUNNING = bindings::net_device_flags_IFF_RUNNING, + NOARP = bindings::net_device_flags_IFF_NOARP, + PROMISC = bindings::net_device_flags_IFF_PROMISC, + ALLMULTI = bindings::net_device_flags_IFF_ALLMULTI, + MASTER = bindings::net_device_flags_IFF_MASTER, + SLAVE = bindings::net_device_flags_IFF_SLAVE, + MULTICAST = bindings::net_device_flags_IFF_MULTICAST, + PORTSEL = bindings::net_device_flags_IFF_PORTSEL, + AUTOMEDIA = bindings::net_device_flags_IFF_AUTOMEDIA, + DYNAMIC = bindings::net_device_flags_IFF_DYNAMIC, + + // #if __UAPI_DEF_IF_NET_DEVICE_FLAGS_LOWER_UP_DORMANT_ECHO // TODO: is this needed? + LOWER = bindings::net_device_flags_IFF_LOWER_UP, + DORMANT = bindings::net_device_flags_IFF_DORMANT, + ECHO = bindings::net_device_flags_IFF_ECHO, +} + +/// Iff private flags +#[repr(u32)] +#[allow(non_camel_case_types)] +pub enum IffPriv { + IFF_802_1Q_VLAN = bindings::netdev_priv_flags_IFF_802_1Q_VLAN, // TODO: find a good name without leading 8 + EBRIDGE = bindings::netdev_priv_flags_IFF_EBRIDGE, + BONDING = bindings::netdev_priv_flags_IFF_BONDING, + ISATAP = bindings::netdev_priv_flags_IFF_ISATAP, + WAN_HDLC = bindings::netdev_priv_flags_IFF_WAN_HDLC, + XMIT_DST_RELEASE = bindings::netdev_priv_flags_IFF_XMIT_DST_RELEASE, + DONT_BRIDGE = bindings::netdev_priv_flags_IFF_DONT_BRIDGE, + DISABLE_NETPOLL = bindings::netdev_priv_flags_IFF_DISABLE_NETPOLL, + MACVLAN_PORT = bindings::netdev_priv_flags_IFF_MACVLAN_PORT, + BRIDGE_PORT = bindings::netdev_priv_flags_IFF_BRIDGE_PORT, + OVS_DATAPATH = bindings::netdev_priv_flags_IFF_OVS_DATAPATH, + TX_SKB_SHARING = bindings::netdev_priv_flags_IFF_TX_SKB_SHARING, + UNICAST_FLT = bindings::netdev_priv_flags_IFF_UNICAST_FLT, + TEAM_PORT = bindings::netdev_priv_flags_IFF_TEAM_PORT, + SUPP_NOFCS = bindings::netdev_priv_flags_IFF_SUPP_NOFCS, + LIVE_ADDR_CHANGE = bindings::netdev_priv_flags_IFF_LIVE_ADDR_CHANGE, + MACVLAN = bindings::netdev_priv_flags_IFF_MACVLAN, + XMIT_DST_RELEASE_PERM = bindings::netdev_priv_flags_IFF_XMIT_DST_RELEASE_PERM, + L3MDEV_MASTER = bindings::netdev_priv_flags_IFF_L3MDEV_MASTER, + NO_QUEUE = bindings::netdev_priv_flags_IFF_NO_QUEUE, + OPENVSWITCH = bindings::netdev_priv_flags_IFF_OPENVSWITCH, + L3MDEV_SLAVE = bindings::netdev_priv_flags_IFF_L3MDEV_SLAVE, + TEAM = bindings::netdev_priv_flags_IFF_TEAM, + RXFH_CONFIGURED = bindings::netdev_priv_flags_IFF_RXFH_CONFIGURED, + PHONY_HEADROOM = bindings::netdev_priv_flags_IFF_PHONY_HEADROOM, + MACSEC = bindings::netdev_priv_flags_IFF_MACSEC, + NO_RX_HANDLER = bindings::netdev_priv_flags_IFF_NO_RX_HANDLER, + FAILOVER = bindings::netdev_priv_flags_IFF_FAILOVER, + FAILOVER_SLAVE = bindings::netdev_priv_flags_IFF_FAILOVER_SLAVE, + L3MDEV_RX_HANDLER = bindings::netdev_priv_flags_IFF_L3MDEV_RX_HANDLER, + LIVE_RENAME_OK = bindings::netdev_priv_flags_IFF_LIVE_RENAME_OK, +} + + +pub mod feature { + use crate::bindings; + + use core::ops::{Deref, DerefMut, Add, Sub, AddAssign, SubAssign}; + use core::convert::{From, Into}; + + #[derive(Debug, Clone, Copy)] + pub struct NetIF(u64); + + impl NetIF { + pub const fn new() -> Self { + Self(0) + } + + pub fn add_flag(&mut self, flag: u64) { + self.0 |= flag; + } + + pub fn remove_flag(&mut self, flag: u64) { + self.0 &= !(flag); + } + } + + impl Deref for NetIF { + type Target = u64; + + fn deref(&self) -> &Self::Target { + &self.0 + } + } + + impl DerefMut for NetIF { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } + } + + impl Add for NetIF { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + Self::from(self.0 | rhs.0) + } + } + + impl Add for NetIF { + type Output = Self; + + fn add(self, rhs: u64) -> Self::Output { + Self::from(self.0 | rhs) + } + } + + impl Sub for NetIF { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + Self::from(self.0 & !rhs.0) + } + } + + impl Sub for NetIF { + type Output = Self; + + fn sub(self, rhs: u64) -> Self::Output { + Self::from(self.0 & !rhs) + } + } + + impl AddAssign for NetIF { + fn add_assign(&mut self, rhs: Self) { + self.0 |= rhs.0 + } + } + + impl AddAssign for NetIF { + fn add_assign(&mut self, rhs: u64) { + self.0 |= rhs + } + } + + impl SubAssign for NetIF { + fn sub_assign(&mut self, rhs: Self) { + self.0 &= !rhs.0 + } + } + + impl SubAssign for NetIF { + fn sub_assign(&mut self, rhs: u64) { + self.0 &= !rhs + } + } + + impl From for NetIF { + fn from(flags: u64) -> Self { + Self(flags) + } + } + + impl Into for NetIF { + fn into(self) -> u64 { + self.0 + } + } + + macro_rules! _netif_f { + ($name:ident, $binding:ident) => { + pub const $name: u64 = 1u64 << $crate::bindings::$binding; + }; + } + + macro_rules! _netif_f_sum { + ($name:ident, $($f:ident),+) => { + pub const $name: u64 = $($crate::net::device::feature::$f |)* 0; + }; + } + + _netif_f!(NETIF_F_SG, NETIF_F_SG_BIT); + _netif_f!(NETIF_F_FRAGLIST, NETIF_F_FRAGLIST_BIT); + _netif_f!(NETIF_F_TSO, NETIF_F_TSO_BIT); + _netif_f!(NETIF_F_TSO6, NETIF_F_TSO6_BIT); + _netif_f!(NETIF_F_TSO_ECN, NETIF_F_TSO_ECN_BIT); + _netif_f!(NETIF_F_TSO_MANGLEID, NETIF_F_TSO_MANGLEID_BIT); + _netif_f!(NETIF_F_GSO_SCTP, NETIF_F_GSO_SCTP_BIT); + _netif_f!(NETIF_F_GSO_UDP_L4, NETIF_F_GSO_UDP_L4_BIT); + _netif_f!(NETIF_F_GSO_FRAGLIST, NETIF_F_GSO_FRAGLIST_BIT); + _netif_f!(NETIF_F_HW_CSUM, NETIF_F_HW_CSUM_BIT); + _netif_f!(NETIF_F_HIGHDMA, NETIF_F_HIGHDMA_BIT); + _netif_f!(NETIF_F_LLTX, NETIF_F_LLTX_BIT); + _netif_f!(NETIF_F_GSO_GRE, NETIF_F_GSO_GRE_BIT); + _netif_f!(NETIF_F_GSO_GRE_CSUM, NETIF_F_GSO_GRE_CSUM_BIT); + _netif_f!(NETIF_F_GSO_IPXIP4, NETIF_F_GSO_IPXIP4_BIT); + _netif_f!(NETIF_F_GSO_IPXIP6, NETIF_F_GSO_IPXIP6_BIT); + _netif_f!(NETIF_F_GSO_UDP_TUNNEL, NETIF_F_GSO_UDP_TUNNEL_BIT); + _netif_f!(NETIF_F_GSO_UDP_TUNNEL_CSUM, NETIF_F_GSO_UDP_TUNNEL_CSUM_BIT); + + _netif_f_sum!(NETIF_F_ALL_TSO, NETIF_F_TSO, NETIF_F_TSO6, NETIF_F_TSO_ECN, NETIF_F_TSO_MANGLEID); + _netif_f_sum!(NETIF_F_GSO_SOFTWARE, NETIF_F_ALL_TSO, NETIF_F_GSO_SCTP, NETIF_F_GSO_UDP_L4, NETIF_F_GSO_FRAGLIST); + _netif_f_sum!(NETIF_F_GSO_ENCAP_ALL, NETIF_F_GSO_GRE, NETIF_F_GSO_GRE_CSUM, NETIF_F_GSO_IPXIP4, NETIF_F_GSO_IPXIP6, NETIF_F_GSO_UDP_TUNNEL, NETIF_F_GSO_UDP_TUNNEL_CSUM); +} + +pub mod helpers { + use super::*; + + pub fn eth_validate_addr(dev: &NetDevice) -> KernelResult<()> { + // SAFETY: dev.ptr is valid if dev is valid + let ret = unsafe { bindings::eth_validate_addr(dev.get_ptr() as *mut bindings::net_device) }; + if ret != 0 { + Err(Error::from_kernel_errno(ret)) + } else { + Ok(()) + } + } + + pub fn eth_mac_addr(dev: &mut NetDevice, p: *mut c_types::c_void) -> KernelResult<()> { + // SAFETY: dev.ptr is valid if dev is valid + let ret = unsafe { bindings::eth_mac_addr( + dev.get_ptr() as *mut bindings::net_device, + p + ) }; + + if ret != 0 { + Err(Error::from_kernel_errno(ret)) + } else { + Ok(()) + } + } +} \ No newline at end of file diff --git a/rust/kernel/net/ethtool.rs b/rust/kernel/net/ethtool.rs new file mode 100644 index 00000000000000..65d8bef867501f --- /dev/null +++ b/rust/kernel/net/ethtool.rs @@ -0,0 +1,262 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Net Device Operations. +//! +//! C header: [`include/linux/netdevice.h`](../../../../include/linux/netdevice.h) + +use core::convert::{TryFrom, TryInto}; +use core::{marker, mem, ops::Deref, ops::DerefMut, pin::Pin, ptr}; + +use crate::bindings; +use crate::{c_types, CStr}; +use crate::error::{Error, KernelResult}; + +use super::device::{NetDeviceAdapter, NetDevice}; + +pub(crate) fn from_kernel_result(r: KernelResult) -> T +where + T: TryFrom, + T::Error: core::fmt::Debug, +{ + match r { + Ok(v) => v, + Err(e) => T::try_from(e.to_kernel_errno()).unwrap(), + } +} + +macro_rules! from_kernel_result { + ($($tt:tt)*) => {{ + from_kernel_result((|| { + $($tt)* + })()) + }}; +} + +unsafe extern "C" fn get_drvinfo_callback(dev: *mut bindings::net_device, info: *mut bindings::ethtool_drvinfo) { + T::EthOps::get_drvinfo( + &NetDevice::::from_ptr(dev), + &mut EthtoolDrvinfo::from_ptr(info) + ); +} + +unsafe extern "C" fn get_ts_info_callback(dev: *mut bindings::net_device, info: *mut bindings::ethtool_ts_info) -> c_types::c_int { + from_kernel_result! { + T::EthOps::get_ts_info( + &NetDevice::::from_ptr(dev), + &mut EthToolTsInfo::from_ptr(info) + )?; + Ok(0) + } +} + +pub(crate) struct EthToolOperationsVtable(marker::PhantomData); + +impl EthToolOperationsVtable { + const VTABLE: bindings::ethtool_ops = bindings::ethtool_ops { + _bitfield_align_1: [], + _bitfield_1: bindings::__BindgenBitfieldUnit::<[u8; 1usize]>::new([0u8; 1usize]), + supported_coalesce_params: 0, + get_drvinfo: if T::EthOps::TO_USE.get_drvinfo { + Some(get_drvinfo_callback::) + } else { + None + }, + get_regs_len: None, + get_regs: None, + get_wol: None, + set_wol: None, + get_msglevel: None, + set_msglevel: None, + nway_reset: None, + get_link: None, + get_link_ext_state: None, + get_eeprom_len: None, + get_eeprom: None, + set_eeprom: None, + get_coalesce: None, + set_coalesce: None, + get_ringparam: None, + set_ringparam: None, + get_pause_stats: None, + get_pauseparam: None, + set_pauseparam: None, + self_test: None, + get_strings: None, + set_phys_id: None, + get_ethtool_stats: None, + begin: None, + complete: None, + get_priv_flags: None, + set_priv_flags: None, + get_sset_count: None, + get_rxnfc: None, + set_rxnfc: None, + flash_device: None, + reset: None, + get_rxfh_key_size: None, + get_rxfh_indir_size: None, + get_rxfh: None, + set_rxfh: None, + get_rxfh_context: None, + set_rxfh_context: None, + get_channels: None, + set_channels: None, + get_dump_flag: None, + get_dump_data: None, + set_dump: None, + get_ts_info: if T::EthOps::TO_USE.get_ts_info { + Some(get_ts_info_callback::) + } else { + None + }, + get_module_info: None, + get_module_eeprom: None, + get_eee: None, + set_eee: None, + get_tunable: None, + set_tunable: None, + get_per_queue_coalesce: None, + set_per_queue_coalesce: None, + get_link_ksettings: None, + set_link_ksettings: None, + get_fecparam: None, + set_fecparam: None, + get_ethtool_phy_stats: None, + get_phy_tunable: None, + set_phy_tunable: None, + }; + + /// Builds an instance of [`struct ethtool_ops`]. + /// + /// # Safety + /// + /// The caller must ensure that the adapter is compatible with the way the device is registered. + pub(crate) const unsafe fn build() -> &'static bindings::ethtool_ops { + &Self::VTABLE + } +} + +/// Represents which fields of [`struct ethtool_ops`] should pe populated with pointers. +pub struct EthToolToUse { + /// The `get_drvinfo` field of [`struct ethtool_ops`]. + pub get_drvinfo: bool, + + pub get_ts_info: bool, +} + +pub const ETH_TOOL_USE_NONE: EthToolToUse = EthToolToUse { + get_drvinfo: false, + get_ts_info: false, +}; + +/// Defines the [`EthToolOps::TO_USE`] field based on a list of fields to be populated. +#[macro_export] +macro_rules! declare_eth_tool_ops { + () => { + const TO_USE: $crate::net::ethtool::EthToolToUse = $crate::net::ethtool::ETH_TOOL_USE_NONE; + }; + ($($i:ident),+) => { + const TO_USE: kernel::net::ethtool::EthToolToUse = + $crate::net::ethtool::EthToolToUse { + $($i: true),+ , + ..$crate::net::ethtool::ETH_TOOL_USE_NONE + }; + }; +} + +pub trait EthToolOps: Send + Sync + Sized { + const TO_USE: EthToolToUse; + + fn get_drvinfo(dev: &NetDevice, info: &mut EthtoolDrvinfo) {} + + fn get_ts_info(dev: &NetDevice, info: &mut EthToolTsInfo) -> KernelResult<()> { + Err(Error::EINVAL) + } +} + +#[repr(transparent)] +pub struct EthToolTsInfo { + ptr: *const bindings::ethtool_ts_info +} + +impl EthToolTsInfo { + /// Constructs a new [`struct ethtool_ts_info`] wrapper. + /// + /// # Safety + /// + /// The pointer `ptr` must be non-null and valid for the lifetime of the object. + pub unsafe fn from_ptr(ptr: *const bindings::ethtool_ts_info) -> Self { + // INVARIANTS: the safety contract ensures the type invariant will hold. + Self { + ptr, + } + } + + pub unsafe fn get_ptr(&self) -> *const bindings::ethtool_ts_info { + self.ptr + } +} + +impl Deref for EthToolTsInfo { + type Target = bindings::ethtool_ts_info; + + fn deref(&self) -> &Self::Target { + // SAFETY: ptr is valid + unsafe { self.ptr.as_ref() }.unwrap() + } +} + +impl DerefMut for EthToolTsInfo { + fn deref_mut(&mut self) -> &mut Self::Target { + // SAFETY: ptr is valid + unsafe { (self.ptr as *mut bindings::ethtool_ts_info).as_mut() }.unwrap() + } +} + +pub struct EthtoolDrvinfo { + ptr: *const bindings::ethtool_drvinfo +} + +impl EthtoolDrvinfo { + /// Constructs a new [`struct ethtool_drvinfo`] wrapper. + /// + /// # Safety + /// + /// The pointer `ptr` must be non-null and valid for the lifetime of the object. + pub unsafe fn from_ptr(ptr: *const bindings::ethtool_drvinfo) -> Self { + // INVARIANTS: the safety contract ensures the type invariant will hold. + Self { + ptr, + } + } + + pub unsafe fn get_ptr(&self) -> *const bindings::ethtool_drvinfo { + self.ptr + } +} + +impl Deref for EthtoolDrvinfo { + type Target = bindings::ethtool_drvinfo; + + fn deref(&self) -> &Self::Target { + // SAFETY: ptr is valid + unsafe { self.ptr.as_ref() }.unwrap() + } +} + +impl DerefMut for EthtoolDrvinfo { + fn deref_mut(&mut self) -> &mut Self::Target { + // SAFETY: ptr is valid + unsafe { (self.ptr as *mut bindings::ethtool_drvinfo).as_mut() }.unwrap() + } +} + +pub mod helpers { + use super::*; + + pub fn ethtool_op_get_ts_info(dev: &NetDevice, info: &mut EthToolTsInfo) -> KernelResult<()> { + // SAFETY: dev.ptr is valid if dev is valid + unsafe { bindings::ethtool_op_get_ts_info(dev.get_ptr() as *mut bindings::net_device, info.get_ptr() as *mut bindings::ethtool_ts_info) }; + Ok(()) + } +} \ No newline at end of file diff --git a/rust/kernel/net/mod.rs b/rust/kernel/net/mod.rs new file mode 100644 index 00000000000000..8d039c15ebb8a8 --- /dev/null +++ b/rust/kernel/net/mod.rs @@ -0,0 +1,70 @@ +use core::mem; + +pub mod device; +pub mod ethtool; +pub mod rtnl; +pub mod netlink; +pub mod skbuff; + +#[doc(inline)] +pub use module::rtnl_link_ops; + +#[cfg(CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS)] +pub unsafe fn is_multicast_ether_addr(addr: *const u8) -> bool { + let a: u32 = *(addr as *const u32); + + if cfg!(target_endian = "big") { + (0x01 & (a >> (((mem::size_of::() as u32) * 8) -8 ))) != 0 + } else { + (0x01 & a) != 0 + } +} + +#[cfg(not(CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS))] +pub unsafe fn is_multicast_ether_addr(addr: *const u8) -> bool { + let a: u16 = *(addr as *const u16); + + if cfg!(target_endian = "big") { + (0x01 & (a >> (((mem::size_of::() as u16) * 8) -8 ))) != 0 + } else { + (0x01 & a) != 0 + } +} + +#[cfg(CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS)] +pub unsafe fn is_zero_ether_addr(addr: *const u8) -> bool { + *(addr as *const u32) | (*((addr as usize + 4) as *const u16) as u32) == 0 +} + +#[cfg(not(CONFIG_HAVE_EFFICIENT_UNALIGNED_ACCESS))] +pub unsafe fn is_zero_ether_addr(addr: *const u8) -> bool { + *(addr as *const u16) | + *((addr as usize + 2) as *const u16) | + *((addr as usize + 4) as *const u16) == 0 +} + +pub unsafe fn is_valid_ether_addr(addr: *const u8) -> bool { + !is_multicast_ether_addr(addr) && !is_zero_ether_addr(addr) +} + +pub mod prelude { + pub use super::{ + device::{ + NetDeviceOps, + NetDevice, + NetDeviceAdapter, + }, + ethtool::{ + self, + EthToolOps, + }, + rtnl::{ + RtnlLock, + RtnlLinkOps, + }, + skbuff::{ + SkBuff + }, + }; + pub use super::rtnl_link_ops; +} diff --git a/rust/kernel/net/netlink.rs b/rust/kernel/net/netlink.rs new file mode 100644 index 00000000000000..b0f50a98af2291 --- /dev/null +++ b/rust/kernel/net/netlink.rs @@ -0,0 +1,138 @@ +// SPDX-License-Identifier: GPL-2.0 +use alloc::vec::Vec; +use core::ops::{Deref, DerefMut}; + +use crate::{bindings, linked_list::Wrapper}; +use crate::{c_types, CStr}; +use crate::error::{Error, KernelResult}; +use crate::sync::Lock; +use crate::user_ptr::{UserSlicePtr, UserSlicePtrReader, UserSlicePtrWriter}; + +use super::{device::{NetDeviceAdapter, NetDevice}, is_valid_ether_addr}; + +pub const ETH_ALEN: u16 = bindings::ETH_ALEN as u16; + +const NLA_HDRLEN: i32 = bindings::BINDINGS_NLA_HDRLEN; +const __IFLA_MAX: usize = bindings::__IFLA_MAX as usize; + +#[repr(transparent)] +pub struct NlAttr(*const bindings::nlattr); + +impl NlAttr { + pub fn is_null(&self) -> bool { + self.0.is_null() + } + + pub fn nla_len(&self) -> u16 { + if self.is_null() { + return 0; + } + + // NO-PANIC: self is valid and not null + // SAFETY: ptr is valid if self is valid + let nlattr = unsafe { self.0.as_ref() }.unwrap(); + nlattr.nla_len - NLA_HDRLEN as u16 + } + + /// Constructs a new [`struct nlattr`] wrapper. + /// + /// # Safety + /// + /// The pointer `ptr` must be non-null and valid for the lifetime of the object. + pub unsafe fn from_ptr(ptr: *const bindings::nlattr) -> Self { + Self(ptr) + } + /*pub unsafe fn from_ptr(ptr: *const bindings::nlattr) -> &'static mut Self { + (ptr as *mut NlAttr).as_mut().unwrap() + }*/ + + pub unsafe fn data(&self) -> *const i8 { + ((self.0 as usize) + NLA_HDRLEN as usize) as *const i8 + } + + pub fn is_valid_ether_addr(&self) -> bool { + // SAFETY: self.o is valid if self is valid + unsafe { + let data = self.data() as *const u8; + super::is_valid_ether_addr(data) + } + } + + unsafe fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +#[repr(transparent)] +pub struct NlExtAck(*const bindings::netlink_ext_ack); + +impl NlExtAck { + /// Constructs a new [`struct netlink_ext_ack`] wrapper. + /// + /// # Safety + /// + /// The pointer `ptr` must be non-null and valid for the lifetime of the object. + pub unsafe fn from_ptr(ptr: *const bindings::netlink_ext_ack) -> Self { + Self(ptr) + } +} + + +#[repr(transparent)] +pub struct NlAttrVec { + ptr: *const *const bindings::nlattr, +} + +impl NlAttrVec { + pub fn get(&self, offset: u32) -> Option { + if offset > __IFLA_MAX as u32 { + return None; + } + + let vec = unsafe { + &*(self.ptr as *const [NlAttr; __IFLA_MAX]) + }; + let nlattr = &vec[offset as usize]; + if nlattr.is_null() { + None + } else { + Some(unsafe { nlattr.clone() }) + } + } + + pub unsafe fn from_ptr(ptr: *const *const bindings::nlattr) -> Self { + /*let vec = *(ptr as *const [NlAttr; __IFLA_MAX]); + Self(vec)*/ + Self { + ptr, + } + } +} + +/*pub struct NlAttrVec<'a>(&'a mut [NlAttr]); + +impl<'a> NlAttrVec<'a> { + pub fn get(&self, offset: u32) -> Option { + if offset > bindings::__IFLA_MAX { + return None; + } + + let nlattr = &self.0[offset as usize]; + if nlattr.is_null() { + None + } else { + Some(unsafe { nlattr.clone() }) + } + } + + /// Constructs a new [`struct nlattr[]`] wrapper. + /// + /// # Safety + /// + /// The pointer `ptr` must be non-null and valid for the lifetime of the object. + /// The pointer `ptr` must be valid for the size of `__IFLA_MAX` * `mem::size_of` + pub unsafe fn from_ptr(ptr: *const *const bindings::nlattr) -> Self { + // TODO: is this correct? + Self(core::slice::from_raw_parts_mut(ptr as *mut NlAttr, bindings::__IFLA_MAX as usize)) + } +}*/ \ No newline at end of file diff --git a/rust/kernel/net/rtnl.rs b/rust/kernel/net/rtnl.rs new file mode 100644 index 00000000000000..b2a8e4a0502a06 --- /dev/null +++ b/rust/kernel/net/rtnl.rs @@ -0,0 +1,141 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! Net Device Operations. +//! +//! C header: [`include/linux/rtnetlink.h`](../../../../include/linux/rtnetlink.h) + +use core::convert::{TryFrom, TryInto}; +use core::{marker, mem, ops::Deref, ops::DerefMut, pin::Pin, ptr}; +use core::cell::UnsafeCell; + +use crate::bindings; +use crate::{c_types, CStr}; +use crate::error::{Error, KernelResult}; +use crate::sync::Lock; +use crate::user_ptr::{UserSlicePtr, UserSlicePtrReader, UserSlicePtrWriter}; + +use super::device::{NetDeviceAdapter, NetDevice}; + +// TODO: inner bool, to allow other unlock mechanism? +#[must_use = "the rtnl unlocks immediately when the guard is unused"] +pub struct RtnlLock{ + _private: () +} + +impl RtnlLock { + pub fn lock() -> Self { + // SAFETY: C function without parameters + unsafe { bindings::rtnl_lock() }; + + Self { + _private: (), + } + } +} + +impl Drop for RtnlLock { + fn drop(&mut self) { + // SAFETY: C function without parameters + unsafe { bindings::rtnl_unlock() }; + } +} + +pub const RTNL_LINK_OPS_EMPTY: bindings::rtnl_link_ops = bindings::rtnl_link_ops { + list: bindings::list_head { + next: ptr::null::() as *mut bindings::list_head, + prev: ptr::null::() as *mut bindings::list_head, + }, + kind: ptr::null::(), + priv_size: 0, + setup: None, + maxtype: 0, + policy: ptr::null::(), + validate: None, + newlink: None, + changelink: None, + dellink: None, + get_size: None, + fill_info: None, + get_xstats_size: None, + fill_xstats: None, + get_num_tx_queues: None, + get_num_rx_queues: None, + slave_maxtype: 0, + slave_policy: ptr::null::(), + slave_changelink: None, + get_slave_size: None, + fill_slave_info: None, + get_link_net: None, + get_linkxstats_size: None, + fill_linkxstats: None, +}; + +#[repr(transparent)] +pub struct RtnlLinkOps(pub bindings::rtnl_link_ops); + +unsafe impl Sync for RtnlLinkOps {} + +impl RtnlLinkOps { + pub fn register(&self) -> KernelResult { + // SAFETY: ptr of self is valid if self is valid + let ret = unsafe { + let ptr = self.get_ptr(); + + bindings::rtnl_link_register(ptr as *mut bindings::rtnl_link_ops) + }; + + if ret != 0 { + Err(Error::from_kernel_errno(ret)) + } else { + Ok(()) + } + } + + pub unsafe fn get_ptr(&self) -> *const bindings::rtnl_link_ops { + self as *const _ as *const bindings::rtnl_link_ops + } + + pub fn unregister(&self) { + let ptr = self as *const _ as *mut bindings::rtnl_link_ops; + + // SAFETY: ptr is valid if self is valid + unsafe { bindings::rtnl_link_unregister(ptr) }; + } +} + +#[repr(transparent)] +pub struct RtnlLinkStats64 { + ptr: *const bindings::rtnl_link_stats64 +} + +impl RtnlLinkStats64 { + pub fn dev_read(&mut self, dev: &NetDevice) { + let stats = self.deref_int(); + // SAFETY: call to C function + unsafe { + bindings::dev_lstats_read( + dev.get_ptr() as *mut bindings::net_device, + &stats.tx_packets as *const u64 as *mut u64, + &stats.tx_bytes as *const u64 as *mut u64 + ); + } + } + + /// Constructs a new [`struct rtnl_link_stats64`] wrapper. + /// + /// # Safety + /// + /// The pointer `ptr` must be non-null and valid for the lifetime of the object. + pub unsafe fn from_ptr(ptr: *const bindings::rtnl_link_stats64) -> Self { + // INVARIANTS: the safety contract ensures the type invariant will hold. + Self { ptr } + } + + + fn deref_int(&self) -> &bindings::rtnl_link_stats64{ + // SAFETY: self.ptr is valid if self is valid + unsafe { + self.ptr.as_ref() + }.unwrap() + } +} \ No newline at end of file diff --git a/rust/kernel/net/skbuff.rs b/rust/kernel/net/skbuff.rs new file mode 100644 index 00000000000000..ab44598f58e202 --- /dev/null +++ b/rust/kernel/net/skbuff.rs @@ -0,0 +1,151 @@ +use core::{ops::Drop, ptr}; + +use crate::bindings; + +/// Wraps the kernel's `struct sk_buff`. +/// +/// # Invariants +/// +/// The pointer [`SkBuff::ptr`] is non-null and valid. +#[repr(transparent)] +pub struct SkBuff { + ptr: *const bindings::sk_buff, +} + +impl SkBuff { + + #[cfg(CONFIG_NETWORK_PHY_TIMESTAMPING)] + pub fn clone_tx_timestamp(&mut self) { + // SAFETY: self.ptr is valid if self is valid + unsafe { + bindings::skb_clone_tx_timestamp(self.ptr as *mut bindings::sk_buff); + } + } + + #[cfg(not(CONFIG_NETWORK_PHY_TIMESTAMPING))] + pub fn clone_tx_timestamp(&mut self) { + // NOOP + } + + /// tx_timestamp - Driver hook for transmit timestamping + /// + /// Ethernet MAC Drivers should call this function in their hard_xmit() + /// function immediately before giving the sk_buff to the MAC hardware. + /// + /// Specifically, one should make absolutely sure that this function is + /// called before TX completion of this packet can trigger. Otherwise + /// the packet could potentially already be freed. + pub fn tx_timestamp(&mut self) { + self.clone_tx_timestamp(); + if (self.shinfo().tx_flags() as u32 & bindings::SKBTX_SW_TSTAMP != 0) { + unsafe { + bindings::skb_tstamp_tx(self.ptr as *mut bindings::sk_buff, ptr::null_mut()); + } + // skb_tstamp_tx(skb, NULL); + } + } + + pub fn len(&self) -> u32 { + let skb = self.deref_int(); + skb.len + } + + /// Constructs a new [`struct sk_buff`] wrapper. + /// + /// # Safety + /// + /// The pointer `ptr` must be non-null and valid for the lifetime of the object. + pub unsafe fn from_ptr(ptr: *const bindings::sk_buff) -> Self { + // INVARIANTS: the safety contract ensures the type invariant will hold. + Self { ptr } + } + + fn deref_int(&self) -> &bindings::sk_buff { + // SAFETY: self.ptr is valid if self is valid + unsafe { + self.ptr.as_ref() + }.unwrap() + } + + pub fn shinfo(&self) -> SkbSharedInfo { + // SAFETY: self.ptr is valid if self is valid + unsafe { + let info = self.shinfo_int(); + SkbSharedInfo::from_ptr(info) + } + } + + unsafe fn shinfo_int(&self) -> *mut bindings::skb_shared_info { + self.end_pointer() as *mut bindings::skb_shared_info + } + + // NET_SKBUFF_DATA_USES_OFFSET + #[cfg(target_pointer_width = "64")] + fn end_pointer(&self) -> *mut u8 { + let sk_reff = self.deref_int(); + (sk_reff.head as usize + sk_reff.end as usize) as *mut u8 + } + + // !NET_SKBUFF_DATA_USES_OFFSET + #[cfg(not(target_pointer_width = "64"))] + fn end_pointer(&self) -> *mut u8 { + let sk_reff = self.deref_int(); + (sk_reff.end) as *mut u8 + } +} + +impl Drop for SkBuff { + + #[cfg(CONFIG_TRACEPOINTS)] + fn drop(&mut self) { + // SAFETY: self.ptr is valid if self is valid + unsafe { + bindings::consume_skb(self.ptr as *mut bindings::sk_buff); + } + } + + #[cfg(not(CONFIG_TRACEPOINTS))] + fn drop(&mut self){ + // SAFETY: self.ptr is valid if self is valid + unsafe { + bindings::kfree_skb(self.ptr as *mut bindings::sk_buff); + } + } +} + + +/// Wraps the kernel's `struct skb_shared_info`. +/// +/// # Invariants +/// +/// The pointer [`SkbSharedInfo::ptr`] is non-null and valid. +#[repr(transparent)] +pub struct SkbSharedInfo { + ptr: *const bindings::skb_shared_info, +} + +impl SkbSharedInfo { + + pub fn tx_flags(&self) -> u8 { + let ref_skb = self.deref_int(); + ref_skb.tx_flags + } + + /// Constructs a new [`struct skb_shared_info`] wrapper. + /// + /// # Safety + /// + /// The pointer `ptr` must be non-null and valid for the lifetime of the object. + pub unsafe fn from_ptr(ptr: *const bindings::skb_shared_info) -> Self { + // INVARIANTS: the safety contract ensures the type invariant will hold. + Self { ptr } + } + + + fn deref_int(&self) -> &bindings::skb_shared_info { + // SAFETY: self.ptr is valid if self is valid + unsafe { + self.ptr.as_ref() + }.unwrap() + } +} \ No newline at end of file diff --git a/rust/module.rs b/rust/module.rs index 471db3df184280..ce12460d92bcca 100644 --- a/rust/module.rs +++ b/rust/module.rs @@ -103,6 +103,14 @@ fn expect_end(it: &mut token_stream::IntoIter) { } } +fn get_ident(it: &mut token_stream::IntoIter, expected_name: &str) -> String { + assert_eq!(expect_ident(it), expected_name); + assert_eq!(expect_punct(it), ':'); + let ident = expect_ident(it); + assert_eq!(expect_punct(it), ','); + ident +} + fn get_literal(it: &mut token_stream::IntoIter, expected_name: &str) -> String { assert_eq!(expect_ident(it), expected_name); assert_eq!(expect_punct(it), ':'); @@ -854,3 +862,188 @@ pub fn module_misc_device(ts: TokenStream) -> TokenStream { .parse() .expect("Error parsing formatted string into token stream.") } + +/// Declares a rtnl link operation table. +/// +/// The `type` argument should match the type used for `T` in `NetDevice`. +/// +/// # Examples +/// +/// ```rust,no_run +/// use kernel::prelude::*; +/// use kernel::net::prelude::*; +/// +/// fn setup(dev: &mut NetDevice) { +/// dev.ether_setup(); +/// // ... +/// } +/// +/// rtnl_link_ops! { +/// kind: b"dummy_rs", +/// type: DummyRsDev, +/// setup: setup, +/// maxtype: 20, +/// } +/// +/// struct DummyRsDev; +/// +/// impl NetDeviceOps for DummyRsDev { +/// kernel::declare_net_device_ops!(); +/// +/// fn init(dev: &NetDevice) -> KernelResult<()> { +/// Ok(()) +/// } +/// +/// fn uninit(dev: &NetDevice) { +/// +/// } +/// } +/// +/// impl EthToolOps for DummyRsDev { +/// kernel::declare_eth_tool_ops!(); +/// } +/// +/// fn call_from_module_init() -> KernelResult<()> { +/// let mut dev = NetDevice::new(DummyRsDev, kernel::cstr!("dummyrs%d"), kernel::net::device::NetNameAssingType::Enum, 1, 1)?; +/// +/// dev.register(); +/// dev.set_rtnl_ops(dummy_rs_rtnl_link_ops); +/// +/// Ok(()) +/// } +/// +/// ``` +#[proc_macro] +pub fn rtnl_link_ops(ts: TokenStream) -> TokenStream { + let mut it = ts.into_iter(); + let literals = vec!["maxtype", "policy", "slave_maxtype", "slave_policy"]; + + let mut found_idents = Vec::new(); + + let kind = get_byte_string(&mut it, "kind"); + let netdevice = get_ident(&mut it, "type"); + + let mut callbacks = String::new(); + let mut fields = String::new(); + + loop { + let name = match it.next() { + Some(TokenTree::Ident(ident)) => ident.to_string(), + Some(_) => panic!("Expected Ident or End"), + None => break, + }; + + assert_eq!(expect_punct(&mut it), ':'); + + if literals.contains(&name.as_str()) { + let literal = expect_literal(&mut it); + fields.push_str(&format!("{name}: {literal},\n", name = name, literal = literal)); + } else { + let func = expect_ident(&mut it); + callbacks.push_str(&build_rtnl_links_callback(&name, &netdevice, &func, &kind)); + found_idents.push(name); + } + + assert_eq!(expect_punct(&mut it), ','); + } + + expect_end(&mut it); + + let callback_fields = found_idents.iter().map(|name| + format!("{}: Some(__rtnl_link_{}_callback_{}),", name, name, kind) + ).collect::>().join("\n"); + + let ops_struct = format!(r#" + #[doc(hidden)] + #[used] + pub static mut {kind}_rtnl_link_ops: kernel::net::rtnl::RtnlLinkOps = kernel::net::rtnl::RtnlLinkOps(kernel::bindings::rtnl_link_ops {{ + priv_size: core::mem::size_of::<<{netdevice} as kernel::net::device::NetDeviceAdapter>::Inner>(), + kind: b"{kind}\0".as_ptr() as *const i8, + {callback_fields} + {fields} + ..kernel::net::rtnl::RTNL_LINK_OPS_EMPTY + }}); + "#, + kind = kind, + netdevice = netdevice, + callback_fields = callback_fields, + fields = fields, + ); + + format!( + r#" + {} + + // TODO: add #[link_section = ".data.read_mosty"] if x86 + #[cfg(any(CONFIG_X86, CONFIG_SPARC64))] + #[link_section = ".data.read_mostly"] + {} + + #[cfg(not(any(CONFIG_X86, CONFIG_SPARC64)))] + {} + "#, + callbacks, + ops_struct, + ops_struct, + ).parse().expect("Error parsing formatted string into token stream.") +} + +struct RtnlLinkValues { + callback_params: String, + return_type: String, + wrapper_before: String, + wrapper_after: String, + params: String, +} + +impl RtnlLinkValues { + fn new(callback_params: &str, wrapper_before: &str, params: &str) -> Self { + Self { + callback_params: callback_params.to_string(), + return_type: "()".to_string(), + wrapper_before: wrapper_before.to_string(), + wrapper_after: "".to_string(), + params: params.to_string(), + } + } +} + +fn get_rtnl_links_values(name: &str, netdevice: &str) -> RtnlLinkValues { + let setup_dev = format!("let mut dev = kernel::net::device::NetDevice::<{}>::from_ptr(dev);", netdevice); + match name { + "setup" => RtnlLinkValues::new("dev: *mut kernel::bindings::net_device", &setup_dev, "&mut dev"), + "validate" => RtnlLinkValues { + callback_params: "tb: *mut *mut kernel::bindings::nlattr, data: *mut *mut kernel::bindings::nlattr, extack: *mut kernel::bindings::netlink_ext_ack".to_string(), + return_type: "kernel::c_types::c_int".to_string(), + wrapper_before: r#"kernel::c_from_kernel_result! { + let tb = kernel::net::netlink::NlAttrVec::from_ptr(tb as *const *const kernel::bindings::nlattr); + let data = kernel::net::netlink::NlAttrVec::from_ptr(data as *const *const kernel::bindings::nlattr); + let extack = kernel::net::netlink::NlExtAck::from_ptr(extack); + "#.to_string(), + wrapper_after: "?; Ok(0) }".to_string(), + params: "&tb, &data, &extack".to_string(), + }, + _ => panic!("invalid rtnl_link_ops function '{}'", name), + } +} + +fn build_rtnl_links_callback(name: &str, netdevice: &str, func: &str, kind: &str) -> String { + let values = get_rtnl_links_values(name, netdevice); + format!(r#" + #[doc(hidden)] + pub unsafe extern "C" fn __rtnl_link_{name}_callback_{kind}({cb_params}) -> {cb_return} {{ + {cb_before} + {cb_func}({cb_r_params}) + {cb_after} + }} + "#, + name = name, + kind = kind, + cb_params = values.callback_params, + cb_return = values.return_type, + cb_before = values.wrapper_before, + cb_func = func, + cb_r_params = values.params, + cb_after = values.wrapper_after, + ) +}