From 717376366a45601741872367cb7e4d6753bae810 Mon Sep 17 00:00:00 2001 From: Minseong Jang Date: Fri, 1 Nov 2024 14:36:24 +0900 Subject: [PATCH] Minor fix --- .../src/gemmini/execute/systolic_array/pe.rs | 27 ++++++++++++------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/hazardflow-designs/src/gemmini/execute/systolic_array/pe.rs b/hazardflow-designs/src/gemmini/execute/systolic_array/pe.rs index 8283aeb..982d394 100644 --- a/hazardflow-designs/src/gemmini/execute/systolic_array/pe.rs +++ b/hazardflow-designs/src/gemmini/execute/systolic_array/pe.rs @@ -4,6 +4,9 @@ use super::*; +/// Bit width of the register type. +const ACC_BITS: usize = 32; + /// PE row data signals. #[derive(Debug, Clone, Copy)] pub struct PeRowData { @@ -51,7 +54,7 @@ pub struct PeControl { /// The number of bits by which the accumulated result of matrix multiplication is right-shifted when leaving the /// systolic array, used to scale down the result. - pub shift: U<5>, + pub shift: U<{ clog2(ACC_BITS) }>, } /// Represents the dataflow. @@ -89,10 +92,10 @@ pub enum Propagate { #[derive(Debug, Default, Clone, Copy)] pub struct PeS { /// Register 1. - pub reg1: S<32>, + pub reg1: S, /// Register 2. - pub reg2: S<32>, + pub reg2: S, /// The propagate value comes from the previous input. /// @@ -102,7 +105,7 @@ pub struct PeS { impl PeS { /// Creates a new PE state. - pub fn new(reg1: S<32>, reg2: S<32>, propagate: Propagate) -> Self { + pub fn new(reg1: S, reg2: S, propagate: Propagate) -> Self { Self { reg1, reg2, propagate } } @@ -113,7 +116,10 @@ impl PeS { /// - `preload`: Bias value for the next operation. /// - `partial_sum`: MAC result of the current operation. /// - `propagate`: Propagate value. - pub fn new_os(preload: S<32>, partial_sum: S<32>, propagate: Propagate) -> Self { + pub fn new_os(preload: S, partial_sum: S, propagate: Propagate) -> Self { + let preload = preload.sext::(); + let partial_sum = partial_sum.sext::(); + match propagate { Propagate::Reg1 => PeS::new(preload, partial_sum, propagate), Propagate::Reg2 => PeS::new(partial_sum, preload, propagate), @@ -127,7 +133,10 @@ impl PeS { /// - `preload`: Weight value for the next operation. /// - `weight`: Weight value for the current operation. /// - `propagate`: Propagate value. - pub fn new_ws(preload: S<32>, weight: S<32>, propagate: Propagate) -> Self { + pub fn new_ws(preload: S, weight: S, propagate: Propagate) -> Self { + let preload = preload.sext::(); + let weight = weight.sext::(); + match propagate { Propagate::Reg1 => PeS::new(preload, weight, propagate), Propagate::Reg2 => PeS::new(weight, preload, propagate), @@ -138,16 +147,16 @@ impl PeS { /// MAC unit (computes `a * b + c`). /// /// It preserves the signedness of operands. -fn mac(a: S<8>, b: S<8>, c: S<32>) -> S { +fn mac(a: S, b: S, c: S) -> S { super::arithmetic::mac(a, b, c) } /// Performs right-shift (`val >> shamt`) and then clips to `OUTPUT_BITS`. /// /// It preserves the signedness of `val`. -fn shift_and_clip(val: S<32>, shamt: U<5>) -> S { +fn shift_and_clip(val: S, shamt: U<{ clog2(ACC_BITS) }>) -> S { let shifted = rounding_shift(val, shamt); - super::arithmetic::clip_with_saturation::<32, OUTPUT_BITS>(shifted) + super::arithmetic::clip_with_saturation::(shifted) } /// PE.