From 223c178bf64edf23d9992324594d52302e6628f7 Mon Sep 17 00:00:00 2001
From: Wedson Almeida Filho <wedsonaf@google.com>
Date: Thu, 11 Nov 2021 18:36:03 +0000
Subject: [PATCH] rust: add `bit` function.

It behaves similarly to C's `BIT` macro. It is used to seamlessly
convert C code that uses this macro, for example the PL061 driver.

Some of this was discussed in Zulip. We couldn't find a simpler solution
that offered the same ergonomics.

Signed-off-by: Wedson Almeida Filho <wedsonaf@google.com>
---
 rust/kernel/lib.rs   |   2 +-
 rust/kernel/types.rs | 119 ++++++++++++++++++++++++++++++++++++++++---
 2 files changed, 114 insertions(+), 7 deletions(-)

diff --git a/rust/kernel/lib.rs b/rust/kernel/lib.rs
index bc5ac70c3ab11e..de00329b5aebab 100644
--- a/rust/kernel/lib.rs
+++ b/rust/kernel/lib.rs
@@ -91,7 +91,7 @@ pub mod user_ptr;
 pub use build_error::build_error;
 
 pub use crate::error::{to_result, Error, Result};
-pub use crate::types::{bits_iter, Mode, Opaque, ScopeGuard};
+pub use crate::types::{bit, bits_iter, Mode, Opaque, ScopeGuard};
 
 use core::marker::PhantomData;
 
diff --git a/rust/kernel/types.rs b/rust/kernel/types.rs
index 5916dff3a45a3b..18fcca75314a22 100644
--- a/rust/kernel/types.rs
+++ b/rust/kernel/types.rs
@@ -302,27 +302,134 @@ impl<T> Opaque<T> {
     }
 }
 
+/// A bitmask.
+///
+/// It has a restriction that all bits must be the same, except one. For example, `0b1110111` and
+/// `0b1000` are acceptable masks.
+#[derive(Clone, Copy)]
+pub struct Bit<T> {
+    index: T,
+    inverted: bool,
+}
+
+/// Creates a bit mask with a single bit set.
+///
+/// # Examples
+///
+/// ```
+/// # use kernel::prelude::*;
+/// # use kernel::bit;
+/// let mut x = 0xfeu32;
+///
+/// assert_eq!(x & bit(0), 0);
+/// assert_eq!(x & bit(1), 2);
+/// assert_eq!(x & bit(2), 4);
+/// assert_eq!(x & bit(3), 8);
+///
+/// x |= bit(0);
+/// assert_eq!(x, 0xff);
+///
+/// x &= !bit(1);
+/// assert_eq!(x, 0xfd);
+///
+/// x &= !bit(7);
+/// assert_eq!(x, 0x7d);
+///
+/// let y: u64 = bit(34).into();
+/// assert_eq!(y, 0x400000000);
+///
+/// assert_eq!(y | bit(35), 0xc00000000);
+/// ```
+pub fn bit<T: Copy>(index: T) -> Bit<T> {
+    Bit {
+        index,
+        inverted: false,
+    }
+}
+
+impl<T: Copy> ops::Not for Bit<T> {
+    type Output = Self;
+    fn not(self) -> Self {
+        Self {
+            index: self.index,
+            inverted: !self.inverted,
+        }
+    }
+}
+
 /// Implemented by integer types that allow counting the number of trailing zeroes.
 pub trait TrailingZeros {
     /// Returns the number of trailing zeroes in the binary representation of `self`.
     fn trailing_zeros(&self) -> u32;
 }
 
-macro_rules! impl_trailing_zeros {
+macro_rules! define_unsigned_number_traits {
     ($type_name:ty) => {
         impl TrailingZeros for $type_name {
             fn trailing_zeros(&self) -> u32 {
                 <$type_name>::trailing_zeros(*self)
             }
         }
+
+        impl<T: Copy> core::convert::From<Bit<T>> for $type_name
+        where
+            Self: ops::Shl<T, Output = Self> + core::convert::From<u8> + ops::Not<Output = Self>,
+        {
+            fn from(v: Bit<T>) -> Self {
+                let c = Self::from(1u8) << v.index;
+                if v.inverted {
+                    !c
+                } else {
+                    c
+                }
+            }
+        }
+
+        impl<T: Copy> ops::BitAnd<Bit<T>> for $type_name
+        where
+            Self: ops::Shl<T, Output = Self> + core::convert::From<u8>,
+        {
+            type Output = Self;
+            fn bitand(self, rhs: Bit<T>) -> Self::Output {
+                self & Self::from(rhs)
+            }
+        }
+
+        impl<T: Copy> ops::BitOr<Bit<T>> for $type_name
+        where
+            Self: ops::Shl<T, Output = Self> + core::convert::From<u8>,
+        {
+            type Output = Self;
+            fn bitor(self, rhs: Bit<T>) -> Self::Output {
+                self | Self::from(rhs)
+            }
+        }
+
+        impl<T: Copy> ops::BitAndAssign<Bit<T>> for $type_name
+        where
+            Self: ops::Shl<T, Output = Self> + core::convert::From<u8>,
+        {
+            fn bitand_assign(&mut self, rhs: Bit<T>) {
+                *self &= Self::from(rhs)
+            }
+        }
+
+        impl<T: Copy> ops::BitOrAssign<Bit<T>> for $type_name
+        where
+            Self: ops::Shl<T, Output = Self> + core::convert::From<u8>,
+        {
+            fn bitor_assign(&mut self, rhs: Bit<T>) {
+                *self |= Self::from(rhs)
+            }
+        }
     };
 }
 
-impl_trailing_zeros!(u8);
-impl_trailing_zeros!(u16);
-impl_trailing_zeros!(u32);
-impl_trailing_zeros!(u64);
-impl_trailing_zeros!(usize);
+define_unsigned_number_traits!(u8);
+define_unsigned_number_traits!(u16);
+define_unsigned_number_traits!(u32);
+define_unsigned_number_traits!(u64);
+define_unsigned_number_traits!(usize);
 
 /// Returns an iterator over the set bits of `value`.
 ///