From 880bdcde4b6514f99ae5d056bb3741d76d3b4aa1 Mon Sep 17 00:00:00 2001 From: John Demme Date: Tue, 28 Jan 2025 22:39:30 +0000 Subject: [PATCH] [ESI][BSP] Adding byte enables to cosim hostmem Writes can happen at less than the granularity of the width of the upstream port, so we must support writing part of the data width. --- frontends/PyCDE/src/pycde/bsp/common.py | 33 +++++++++++++++---- .../ESI/runtime/cpp/include/esi/Common.h | 2 +- lib/Dialect/ESI/runtime/cpp/lib/Common.cpp | 2 +- .../ESI/runtime/cpp/lib/backends/Cosim.cpp | 10 ++++-- .../ESI/runtime/cpp/tools/esitester.cpp | 20 ++++++++--- 5 files changed, 52 insertions(+), 15 deletions(-) diff --git a/frontends/PyCDE/src/pycde/bsp/common.py b/frontends/PyCDE/src/pycde/bsp/common.py index 056890796433..4d704e047df3 100644 --- a/frontends/PyCDE/src/pycde/bsp/common.py +++ b/frontends/PyCDE/src/pycde/bsp/common.py @@ -6,8 +6,8 @@ from math import ceil from ..common import Clock, Input, InputChannel, Output, OutputChannel, Reset -from ..constructs import (AssignableSignal, ControlReg, Counter, NamedWire, Reg, - Wire) +from ..constructs import (AssignableSignal, ControlReg, Counter, Mux, NamedWire, + Reg, Wire) from .. 
import esi from ..module import Module, generator, modparams from ..signals import BitsSignal, BundleSignal, ChannelSignal @@ -458,6 +458,8 @@ def TaggedWriteGearbox(input_bitwidth: int, if output_bitwidth % 8 != 0: raise ValueError("Output bitwidth must be a multiple of 8.") + if input_bitwidth % 8 != 0: + raise ValueError("Input bitwidth must be a multiple of 8.") class TaggedWriteGearboxImpl(Module): clk = Clock() @@ -473,8 +475,11 @@ class TaggedWriteGearboxImpl(Module): ("address", UInt(64)), ("tag", esi.HostMem.TagType), ("data", Bits(output_bitwidth)), + ("valid_bytes", Bits(8)), ])) + num_chunks = ceil(input_bitwidth / output_bitwidth) + @generator def build(ports): upstream_ready = Wire(Bits(1)) @@ -482,6 +487,8 @@ def build(ports): client_tag_and_data, client_valid = ports.in_.unwrap(ready_for_client) client_data = client_tag_and_data.data client_xact = ready_for_client & client_valid + input_bitwidth_bytes = input_bitwidth // 8 + output_bitwidth_bytes = output_bitwidth // 8 # Determine if gearboxing is necessary and whether it needs to be # gearboxed up or just sliced down. @@ -491,19 +498,23 @@ def build(ports): ready_for_client.assign(upstream_ready) tag = client_tag_and_data.tag address = client_tag_and_data.address + valid_bytes = Bits(8)(input_bitwidth_bytes) elif output_bitwidth > input_bitwidth: upstream_data_bits = client_data.as_bits(output_bitwidth) upstream_valid = client_valid ready_for_client.assign(upstream_ready) tag = client_tag_and_data.tag address = client_tag_and_data.address + valid_bytes = Bits(8)(input_bitwidth_bytes) else: # Create registers equal to the number of upstream transactions needed # to complete the transmission. 
- num_chunks = ceil(input_bitwidth / output_bitwidth) + num_chunks = TaggedWriteGearboxImpl.num_chunks num_chunks_idx_bitwidth = clog2(num_chunks) - padding = Bits(output_bitwidth - (input_bitwidth % output_bitwidth))(0) - client_data_padded = BitsSignal.concat([padding, client_data]) + padding_numbits = output_bitwidth - (input_bitwidth % output_bitwidth) + assert padding_numbits % 8 == 0, "Padding must be a multiple of 8." + client_data_padded = BitsSignal.concat( + [Bits(padding_numbits)(0), client_data]) chunks = [ client_data_padded[i * output_bitwidth:(i + 1) * output_bitwidth] for i in range(num_chunks) @@ -537,12 +548,16 @@ def build(ports): name="address_reg") address = (addr_reg + counter_bytes).as_uint(64) tag = tag_reg + valid_bytes = Mux(counter.out == (num_chunks - 1), + Bits(8)(output_bitwidth_bytes), + Bits(8)(padding_numbits // 8)) upstream_channel, upstrm_ready_sig = TaggedWriteGearboxImpl.out.type.wrap( { "address": address, "tag": tag, "data": upstream_data_bits, + "valid_bytes": valid_bytes }, upstream_valid) upstream_ready.assign(upstrm_ready_sig) ports.out = upstream_channel @@ -584,7 +599,8 @@ def build(ports): { "address": 0, "tag": 0, - "data": 0 + "data": 0, + "valid_bytes": 0, }, 0) write_bundle, _ = hostmem_module.write.type.pack(req=req) ports.upstream = write_bundle @@ -633,6 +649,7 @@ def build(ports): "address": m.address, "tag": idx, "data": m.data, + "valid_bytes": m.valid_bytes }))) # Set the port for the client request. 
setattr(ports, HostMemWriteProcessorImpl.reqPortMap[req], bundle_sig) @@ -670,10 +687,14 @@ class ChannelHostMemImpl(esi.ServiceImplementation): ("data", Bits(read_width)), ])), ])) + + if write_width % 8 != 0: + raise ValueError("Write width must be a multiple of 8.") UpstreamWriteReq = StructType([ ("address", UInt(64)), ("tag", UInt(8)), ("data", Bits(write_width)), + ("valid_bytes", Bits(8)), ]) write = Output( Bundle([ diff --git a/lib/Dialect/ESI/runtime/cpp/include/esi/Common.h b/lib/Dialect/ESI/runtime/cpp/include/esi/Common.h index f7c82901b1f2..a4078ba8b7a9 100644 --- a/lib/Dialect/ESI/runtime/cpp/include/esi/Common.h +++ b/lib/Dialect/ESI/runtime/cpp/include/esi/Common.h @@ -147,7 +147,7 @@ std::ostream &operator<<(std::ostream &, const esi::AppID &); //===----------------------------------------------------------------------===// namespace esi { -std::string toHex(uint32_t val); +std::string toHex(uint64_t val); } // namespace esi #endif // ESI_COMMON_H diff --git a/lib/Dialect/ESI/runtime/cpp/lib/Common.cpp b/lib/Dialect/ESI/runtime/cpp/lib/Common.cpp index 50b4308cff7b..0a95890804fe 100644 --- a/lib/Dialect/ESI/runtime/cpp/lib/Common.cpp +++ b/lib/Dialect/ESI/runtime/cpp/lib/Common.cpp @@ -34,7 +34,7 @@ std::string MessageData::toHex() const { return ss.str(); } -std::string esi::toHex(uint32_t val) { +std::string esi::toHex(uint64_t val) { std::ostringstream ss; ss << std::hex << val; return ss.str(); diff --git a/lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp b/lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp index c99f95e419ba..88e83674ba33 100644 --- a/lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp +++ b/lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp @@ -428,6 +428,7 @@ struct HostMemReadResp { }; struct HostMemWriteReq { + uint8_t valid_bytes; uint64_t data; uint8_t tag; uint64_t address; @@ -540,10 +541,13 @@ class CosimHostMem : public HostMem { std::unique_ptr> &details) { subsystem = "HostMem"; msg = "Write request: addr=0x" + 
toHex(req->address) + " data=0x" + - toHex(req->data) + " tag=" + std::to_string(req->tag); + toHex(req->data) + + " valid_bytes=" + std::to_string(req->valid_bytes) + + " tag=" + std::to_string(req->tag); }); - uint64_t *dataPtr = reinterpret_cast(req->address); - *dataPtr = req->data; + uint8_t *dataPtr = reinterpret_cast(req->address); + for (uint8_t i = 0; i < req->valid_bytes; ++i) + dataPtr[i] = (req->data >> (i * 8)) & 0xFF; HostMemWriteResp resp = req->tag; return MessageData::from(resp); } diff --git a/lib/Dialect/ESI/runtime/cpp/tools/esitester.cpp b/lib/Dialect/ESI/runtime/cpp/tools/esitester.cpp index d0fac38d1672..73b0eec3c390 100644 --- a/lib/Dialect/ESI/runtime/cpp/tools/esitester.cpp +++ b/lib/Dialect/ESI/runtime/cpp/tools/esitester.cpp @@ -151,11 +151,16 @@ void dmaTest(Accelerator *acc, esi::services::HostMem::HostMemRegion ®ion, // Initiate a test write. // TODO: remove the width == 96 once multiplexing support is added. if (write) { - auto check = [&]() { + assert(width % 8 == 0); + auto check = [&](bool print) { region.flush(); - for (size_t i = 0, e = (width + 63) / 64; i < e; ++i) + for (size_t i = 0, e = (width + 63) / 64; i < e; ++i) { + if (print) + std::cout << "dataPtr[" << i << "] = 0x" << esi::toHex(dataPtr[i]) + << std::endl; if (dataPtr[i] == 0xFFFFFFFFFFFFFFFFull) return false; + } return true; }; @@ -179,12 +184,19 @@ void dmaTest(Accelerator *acc, esi::services::HostMem::HostMemRegion ®ion, writeMem->write(0, reinterpret_cast(devicePtr)); // Wait for the accelerator to write. Timeout and fail after 10ms. for (int i = 0; i < 100; ++i) { - if (check()) + if (check(false)) break; std::this_thread::sleep_for(std::chrono::microseconds(100)); } - if (!check()) + if (!check(true)) throw std::runtime_error("DMA write test failed"); + + // Check that the accelerator didn't write too far. 
+ size_t widthInBytes = width / 8; + uint8_t *dataPtr8 = reinterpret_cast<uint8_t *>(region.getPtr()); + for (size_t i = widthInBytes, e = ((widthInBytes + 7) / 8) * 8; i < e; ++i) /* e is the byte offset of the end of the last 64-bit word inspected by check(); the previous bound was a word count, so the loop never ran */ + if (dataPtr8[i] != 0xFF) + throw std::runtime_error("DMA write test failed -- write went too far"); } }