Skip to content

Commit

Permalink
Define signed integer
Browse files Browse the repository at this point in the history
  • Loading branch information
minseongg committed Nov 1, 2024
1 parent 2f20e3c commit 4681dc1
Show file tree
Hide file tree
Showing 10 changed files with 207 additions and 167 deletions.
4 changes: 2 additions & 2 deletions hazardflow-designs/src/cpu/multiplier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ pub fn muldiv<P: Copy, R: Copy>(
let mpcand = s.divisor;

let prod = {
let mpcand = mpcand.sext::<34>();
let accum = accum.sext::<34>();
let mpcand = U::from(S::from(mpcand).sext::<34>());
let accum = U::from(S::from(accum).sext::<34>());

if !mplier[0] {
accum
Expand Down
26 changes: 19 additions & 7 deletions hazardflow-designs/src/gemmini/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,21 @@
use super::*;

/// MAC unit (computes `a * b + c`).
///
/// It preserves the signedness of operands.
pub fn mac(a: S<8>, b: S<8>, c: S<32>) -> S<OUTPUT_BITS> {
let a = u32::from(U::from(a.sext::<32>())) as i32;
let b = u32::from(U::from(b.sext::<32>())) as i32;
let c = u32::from(U::from(c)) as i32;
S::from((a * b + c).into_u())
}

/// Rounding shift (round-to-nearest-even)
/// <https://github.com/ucb-bar/gemmini/blob/be2e9f26181658895ebc7ca7f7d6be6210f5cdef/src/main/scala/gemmini/Arithmetic.scala#L97C7-L97C22>
/// <https://github.com/riscv/riscv-v-spec/blob/master/v-spec.adoc#38-vector-fixed-point-rounding-mode-register-vxrm>
pub fn rounding_shift(val: U<32>, shamt: U<5>) -> U<32> {
pub fn rounding_shift(val: S<32>, shamt: U<5>) -> S<32> {
let val = U::from(val);
let val_i32 = u32::from(val) as i32; // $signed(c1)
let shamt_usize = u32::from(shamt) as usize;
let round_down_shifted = val_i32 >> u32::from(shamt);
Expand All @@ -24,20 +35,21 @@ pub fn rounding_shift(val: U<32>, shamt: U<5>) -> U<32> {
// d != 0 && v[d-1] && (v[d-2:0]!=0 | v[d])
let r = (nonzero_shamt & val[shamt_usize - 1] & (zeros | val[shamt_usize])) as i32;

(round_down_shifted + r).into_u()
S::from((round_down_shifted + r).into_u())
}

/// Same as `clippedToWidthOf` function.
/// <https://github.com/ucb-bar/gemmini/blob/be2e9f26181658895ebc7ca7f7d6be6210f5cdef/src/main/scala/gemmini/Arithmetic.scala#L122C20-L126>
pub fn clip_with_saturation<const N: usize, const M: usize>(val: U<N>) -> U<M>
pub fn clip_with_saturation<const N: usize, const M: usize>(val: S<N>) -> S<M>
where
[(); M - 1]:,
[(); (M - 1) + 1]:,
{
let val = u32::from(val) as i32;
let val = u32::from(U::from(val)) as i32;

let sat_max = u32::from(U::<M>::signed_max()) as i32;
let sat_min = u32::from(U::<M>::signed_min().resize::<20>().sext::<32>()) as i32;
// TODO: Fix `sat_min` logic.
let sat_max = u32::from(U::from(S::<M>::signed_max())) as i32;
let sat_min = u32::from(U::from(S::<M>::signed_min().resize::<20>().sext::<32>())) as i32;
let clipped = if val > sat_max {
sat_max
} else if val < sat_min {
Expand All @@ -46,5 +58,5 @@ where
val
};

clipped.into_u()
S::from(clipped.into_u())
}
6 changes: 3 additions & 3 deletions hazardflow-designs/src/gemmini/execute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -739,8 +739,8 @@ fn compute_write_signal(resp: (MeshRespExtended, (Dataflow, U<3>, U<16>))) -> Wr
#[allow(clippy::identity_op)]
fn clip_with_saturation(val: U<20>) -> U<8> {
let val_msb = val[20 - 1];
let sat_max = U::<8>::signed_max();
let sat_min = U::<8>::signed_min();
let sat_max = U::from(S::<8>::signed_max());
let sat_min = U::from(S::<8>::signed_min());

// TODO: Better way for signed comparison? Modify compiler for signed comparison.
if !val_msb && val > sat_max.resize() {
Expand Down Expand Up @@ -796,7 +796,7 @@ fn acc_write_req(resp: (MeshRespExtended, (Dataflow, U<3>, U<16>)), bank_i: U<1>
let write_signals = compute_write_signal(resp);
let resp = resp.0;

let wdata = resp.mesh_resp.data.map(|v| v.sext::<32>());
let wdata = resp.mesh_resp.data.map(|v| U::from(S::from(v).sext::<32>()));
let wmask = write_signals.w_mask.map(|v| v.repeat::<4>()).concat();

if write_signals.start_array_outputting
Expand Down
26 changes: 11 additions & 15 deletions hazardflow-designs/src/gemmini/execute/systolic_array/pe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub struct PeRowData {
/// A.
///
/// Represents the activation value.
pub a: U<INPUT_BITS>,
pub a: S<INPUT_BITS>,
}

/// PE column data signals.
Expand All @@ -19,12 +19,12 @@ pub struct PeColData {
/// B.
///
/// Represents the weight value (in OS dataflow) or the above PE's MAC result (in WS dataflow).
pub b: U<OUTPUT_BITS>,
pub b: S<OUTPUT_BITS>,

/// D.
///
/// Represents the preloading bias value (in OS dataflow) or the preloading weight value (in WS dataflow).
pub d: U<OUTPUT_BITS>,
pub d: S<OUTPUT_BITS>,
}

/// PE column control signals.
Expand Down Expand Up @@ -89,10 +89,10 @@ pub enum Propagate {
#[derive(Debug, Default, Clone, Copy)]
pub struct PeS {
/// Register 1.
pub reg1: U<32>,
pub reg1: S<32>,

/// Register 2.
pub reg2: U<32>,
pub reg2: S<32>,

/// The propagate value comes from the previous input.
///
Expand All @@ -102,7 +102,7 @@ pub struct PeS {

impl PeS {
/// Creates a new PE state.
pub fn new(reg1: U<32>, reg2: U<32>, propagate: Propagate) -> Self {
pub fn new(reg1: S<32>, reg2: S<32>, propagate: Propagate) -> Self {
Self { reg1, reg2, propagate }
}

Expand All @@ -113,7 +113,7 @@ impl PeS {
/// - `preload`: Bias value for the next operation.
/// - `partial_sum`: MAC result of the current operation.
/// - `propagate`: Propagate value.
pub fn new_os(preload: U<32>, partial_sum: U<32>, propagate: Propagate) -> Self {
pub fn new_os(preload: S<32>, partial_sum: S<32>, propagate: Propagate) -> Self {
match propagate {
Propagate::Reg1 => PeS::new(preload, partial_sum, propagate),
Propagate::Reg2 => PeS::new(partial_sum, preload, propagate),
Expand All @@ -127,7 +127,7 @@ impl PeS {
/// - `preload`: Weight value for the next operation.
/// - `weight`: Weight value for the current operation.
/// - `propagate`: Propagate value.
pub fn new_ws(preload: U<32>, weight: U<32>, propagate: Propagate) -> Self {
pub fn new_ws(preload: S<32>, weight: S<32>, propagate: Propagate) -> Self {
match propagate {
Propagate::Reg1 => PeS::new(preload, weight, propagate),
Propagate::Reg2 => PeS::new(weight, preload, propagate),
Expand All @@ -138,18 +138,14 @@ impl PeS {
/// MAC unit (computes `a * b + c`).
///
/// It preserves the signedness of operands.
fn mac(a: U<8>, b: U<8>, c: U<32>) -> U<OUTPUT_BITS> {
let a = u32::from(a.sext::<32>()) as i32;
let b = u32::from(b.sext::<32>()) as i32;
let c = u32::from(c) as i32;

(a * b + c).into_u()
fn mac(a: S<8>, b: S<8>, c: S<32>) -> S<OUTPUT_BITS> {
super::arithmetic::mac(a, b, c)
}

/// Performs right-shift (`val >> shamt`) and then clips to `OUTPUT_BITS`.
///
/// It preserves the signedness of `val`.
fn shift_and_clip(val: U<32>, shamt: U<5>) -> U<OUTPUT_BITS> {
fn shift_and_clip(val: S<32>, shamt: U<5>) -> S<OUTPUT_BITS> {
let shifted = rounding_shift(val, shamt);
super::arithmetic::clip_with_saturation::<32, OUTPUT_BITS>(shifted)
}
Expand Down
2 changes: 2 additions & 0 deletions hazardflow-designs/src/std/value/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@ use crate::prelude::*;
mod array;
mod bounded;
mod option;
mod sint;
mod uint;

pub use array::*;
pub use bounded::*;
pub use option::*;
pub use sint::*;
pub use uint::*;

/// Don't care value.
Expand Down
63 changes: 63 additions & 0 deletions hazardflow-designs/src/std/value/sint.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
//! Signed integer.
use super::*;

/// An signed integer with bitwidth `N`.
#[derive(Debug, Default, Clone, Copy)]
pub struct S<const N: usize>(U<N>);

impl<const N: usize> S<N> {
/// Sign extends `S<N>` to `S<M>`.
///
/// It panics when `M < N`.
#[allow(clippy::identity_op)]
pub fn sext<const M: usize>(self) -> S<M>
where
[(); (M - N) * 1]:,
[(); N + (M - N)]:,
{
if M >= N {
let msb_arr: Array<bool, { M - N }> = self.0.clip_const::<1>(N - 1).repeat::<{ M - N }>().concat().resize();
S(self.0.append(msb_arr).resize::<M>())
} else {
panic!("M should be larger than N")
}
}

/// Resizes the bitwidth.
///
/// It does not preserves the signedness.
pub fn resize<const M: usize>(self) -> S<M> {
S::from(U::from(self).resize())
}

/// Returns the maximum value of an `N` bit signed value. (i.e., 2^(`N` - 1) - 1)
pub fn signed_max() -> S<N>
where
[(); N - 1]:,
[(); (N - 1) + 1]:,
{
S::from(U::<N>::unsigned_max().clip_const::<{ N - 1 }>(0).append(U::<1>::from(0)).resize::<N>())
}

/// Returns the minimum value of an `N` bit unsigned value. (i.e., -2^(`N` - 1))
pub fn signed_min() -> S<N>
where
[(); N - 1]:,
[(); (N - 1) + 1]:,
{
S::from(U::<{ N - 1 }>::from(0).append(U::<1>::from(1)).resize::<N>())
}
}

impl<const N: usize> From<U<N>> for S<N> {
fn from(value: U<N>) -> S<N> {
S(value)
}
}

impl<const N: usize> From<S<N>> for U<N> {
fn from(value: S<N>) -> U<N> {
value.0
}
}
35 changes: 1 addition & 34 deletions hazardflow-designs/src/std/value/uint.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
//! Integer.
//! Unsigned integer.
use core::cmp::Ordering;
use core::ops::*;
Expand Down Expand Up @@ -135,46 +135,13 @@ where [(); N + 1]:
pub fn trunk_add(self, rhs: Self) -> Self {
(self + rhs).resize()
}

/// Sign extends `U<N>` to `U<M>`.
pub fn sext<const M: usize>(self) -> U<M>
where
[(); (M - N) * 1]:,
[(); M * N]:,
[(); N + (M - N)]:,
{
if M >= N {
let msb_arr: Array<bool, { M - N }> = self.clip_const::<1>(N - 1).repeat::<{ M - N }>().concat().resize();
self.append(msb_arr).resize::<M>()
} else {
panic!("M should be larger than N")
}
}
}

impl<const N: usize> U<N> {
/// Returns the maximum value of an `N` bit unsigned value. (i.e., 2^`N` - 1)
pub fn unsigned_max() -> U<N> {
true.repeat::<N>()
}

/// Returns the maximum value of an `N` bit signed value. (i.e., 2^(`N` - 1) - 1)
pub fn signed_max() -> U<N>
where
[(); N - 1]:,
[(); (N - 1) + 1]:,
{
Self::unsigned_max().clip_const::<{ N - 1 }>(0).append(U::<1>::from(0)).resize::<N>()
}

/// Returns the minimum value of an `N` bit unsigned value. (i.e., -2^(`N` - 1))
pub fn signed_min() -> U<N>
where
[(); N - 1]:,
[(); (N - 1) + 1]:,
{
U::<{ N - 1 }>::from(0).append(U::<1>::from(1)).resize::<N>()
}
}

impl<const N: usize> Sub<U<N>> for U<N> {
Expand Down
Loading

0 comments on commit 4681dc1

Please sign in to comment.