diff --git a/frontends/PyCDE/src/pycde/bsp/common.py b/frontends/PyCDE/src/pycde/bsp/common.py index 056890796433..4d704e047df3 100644 --- a/frontends/PyCDE/src/pycde/bsp/common.py +++ b/frontends/PyCDE/src/pycde/bsp/common.py @@ -6,8 +6,8 @@ from math import ceil from ..common import Clock, Input, InputChannel, Output, OutputChannel, Reset -from ..constructs import (AssignableSignal, ControlReg, Counter, NamedWire, Reg, - Wire) +from ..constructs import (AssignableSignal, ControlReg, Counter, Mux, NamedWire, + Reg, Wire) from .. import esi from ..module import Module, generator, modparams from ..signals import BitsSignal, BundleSignal, ChannelSignal @@ -458,6 +458,8 @@ def TaggedWriteGearbox(input_bitwidth: int, if output_bitwidth % 8 != 0: raise ValueError("Output bitwidth must be a multiple of 8.") + if input_bitwidth % 8 != 0: + raise ValueError("Input bitwidth must be a multiple of 8.") class TaggedWriteGearboxImpl(Module): clk = Clock() @@ -473,8 +475,11 @@ class TaggedWriteGearboxImpl(Module): ("address", UInt(64)), ("tag", esi.HostMem.TagType), ("data", Bits(output_bitwidth)), + ("valid_bytes", Bits(8)), ])) + num_chunks = ceil(input_bitwidth / output_bitwidth) + @generator def build(ports): upstream_ready = Wire(Bits(1)) @@ -482,6 +487,8 @@ def build(ports): client_tag_and_data, client_valid = ports.in_.unwrap(ready_for_client) client_data = client_tag_and_data.data client_xact = ready_for_client & client_valid + input_bitwidth_bytes = input_bitwidth // 8 + output_bitwidth_bytes = output_bitwidth // 8 # Determine if gearboxing is necessary and whether it needs to be # gearboxed up or just sliced down. @@ -491,19 +498,23 @@ def build(ports): ready_for_client.assign(upstream_ready) tag = client_tag_and_data.tag address = client_tag_and_data.address + valid_bytes = Bits(8)(input_bitwidth_bytes) elif output_bitwidth > input_bitwidth: upstream_data_bits = client_data.as_bits(output_bitwidth) upstream_valid = client_valid ready_for_client.assign(upstream_ready) tag = client_tag_and_data.tag address = client_tag_and_data.address + valid_bytes = Bits(8)(input_bitwidth_bytes) else: # Create registers equal to the number of upstream transactions needed # to complete the transmission. - num_chunks = ceil(input_bitwidth / output_bitwidth) + num_chunks = TaggedWriteGearboxImpl.num_chunks num_chunks_idx_bitwidth = clog2(num_chunks) - padding = Bits(output_bitwidth - (input_bitwidth % output_bitwidth))(0) - client_data_padded = BitsSignal.concat([padding, client_data]) + padding_numbits = output_bitwidth - (input_bitwidth % output_bitwidth) + assert padding_numbits % 8 == 0, "Padding must be a multiple of 8." + client_data_padded = BitsSignal.concat( + [Bits(padding_numbits)(0), client_data]) chunks = [ client_data_padded[i * output_bitwidth:(i + 1) * output_bitwidth] for i in range(num_chunks) @@ -537,12 +548,16 @@ def build(ports): name="address_reg") address = (addr_reg + counter_bytes).as_uint(64) tag = tag_reg + valid_bytes = Mux(counter.out == (num_chunks - 1), + Bits(8)(output_bitwidth_bytes), + Bits(8)(padding_numbits // 8)) upstream_channel, upstrm_ready_sig = TaggedWriteGearboxImpl.out.type.wrap( { "address": address, "tag": tag, "data": upstream_data_bits, + "valid_bytes": valid_bytes }, upstream_valid) upstream_ready.assign(upstrm_ready_sig) ports.out = upstream_channel @@ -584,7 +599,8 @@ def build(ports): { "address": 0, "tag": 0, - "data": 0 + "data": 0, + "valid_bytes": 0, }, 0) write_bundle, _ = hostmem_module.write.type.pack(req=req) ports.upstream = write_bundle @@ -633,6 +649,7 @@ def build(ports): "address": m.address, "tag": idx, "data": m.data, + "valid_bytes": m.valid_bytes }))) # Set the port for the client request. setattr(ports, HostMemWriteProcessorImpl.reqPortMap[req], bundle_sig) @@ -670,10 +687,14 @@ class ChannelHostMemImpl(esi.ServiceImplementation): ("data", Bits(read_width)), ])), ])) + + if write_width % 8 != 0: + raise ValueError("Write width must be a multiple of 8.") UpstreamWriteReq = StructType([ ("address", UInt(64)), ("tag", UInt(8)), ("data", Bits(write_width)), + ("valid_bytes", Bits(8)), ]) write = Output( Bundle([ diff --git a/lib/Dialect/ESI/runtime/cpp/include/esi/Common.h b/lib/Dialect/ESI/runtime/cpp/include/esi/Common.h index f7c82901b1f2..a4078ba8b7a9 100644 --- a/lib/Dialect/ESI/runtime/cpp/include/esi/Common.h +++ b/lib/Dialect/ESI/runtime/cpp/include/esi/Common.h @@ -147,7 +147,7 @@ std::ostream &operator<<(std::ostream &, const esi::AppID &); //===----------------------------------------------------------------------===// namespace esi { -std::string toHex(uint32_t val); +std::string toHex(uint64_t val); } // namespace esi #endif // ESI_COMMON_H diff --git a/lib/Dialect/ESI/runtime/cpp/lib/Common.cpp b/lib/Dialect/ESI/runtime/cpp/lib/Common.cpp index 50b4308cff7b..0a95890804fe 100644 --- a/lib/Dialect/ESI/runtime/cpp/lib/Common.cpp +++ b/lib/Dialect/ESI/runtime/cpp/lib/Common.cpp @@ -34,7 +34,7 @@ std::string MessageData::toHex() const { return ss.str(); } -std::string esi::toHex(uint32_t val) { +std::string esi::toHex(uint64_t val) { std::ostringstream ss; ss << std::hex << val; return ss.str(); diff --git a/lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp b/lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp index c99f95e419ba..88e83674ba33 100644 --- a/lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp +++ b/lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp @@ -428,6 +428,7 @@ struct HostMemReadResp { }; struct HostMemWriteReq { + uint8_t valid_bytes; uint64_t data; uint8_t tag; uint64_t address; @@ -540,10 +541,13 @@ class CosimHostMem : public HostMem { std::unique_ptr> &details) { subsystem = "HostMem"; msg = "Write request: addr=0x" + toHex(req->address) + " data=0x" + - toHex(req->data) + " tag=" + std::to_string(req->tag); + toHex(req->data) + + " valid_bytes=" + std::to_string(req->valid_bytes) + + " tag=" + std::to_string(req->tag); }); - uint64_t *dataPtr = reinterpret_cast(req->address); - *dataPtr = req->data; + uint8_t *dataPtr = reinterpret_cast(req->address); + for (uint8_t i = 0; i < req->valid_bytes; ++i) + dataPtr[i] = (req->data >> (i * 8)) & 0xFF; HostMemWriteResp resp = req->tag; return MessageData::from(resp); } diff --git a/lib/Dialect/ESI/runtime/cpp/tools/esitester.cpp b/lib/Dialect/ESI/runtime/cpp/tools/esitester.cpp index d0fac38d1672..73b0eec3c390 100644 --- a/lib/Dialect/ESI/runtime/cpp/tools/esitester.cpp +++ b/lib/Dialect/ESI/runtime/cpp/tools/esitester.cpp @@ -151,11 +151,16 @@ void dmaTest(Accelerator *acc, esi::services::HostMem::HostMemRegion ®ion, // Initiate a test write. // TODO: remove the width == 96 once multiplexing support is added. if (write) { - auto check = [&]() { + assert(width % 8 == 0); + auto check = [&](bool print) { region.flush(); - for (size_t i = 0, e = (width + 63) / 64; i < e; ++i) + for (size_t i = 0, e = (width + 63) / 64; i < e; ++i) { + if (print) + std::cout << "dataPtr[" << i << "] = 0x" << esi::toHex(dataPtr[i]) + << std::endl; if (dataPtr[i] == 0xFFFFFFFFFFFFFFFFull) return false; + } return true; }; @@ -179,12 +184,19 @@ void dmaTest(Accelerator *acc, esi::services::HostMem::HostMemRegion ®ion, writeMem->write(0, reinterpret_cast(devicePtr)); // Wait for the accelerator to write. Timeout and fail after 10ms. for (int i = 0; i < 100; ++i) { - if (check()) + if (check(false)) break; std::this_thread::sleep_for(std::chrono::microseconds(100)); } - if (!check()) + if (!check(true)) throw std::runtime_error("DMA write test failed"); + + // Check that the accelerator didn't write too far. + size_t widthInBytes = width / 8; + uint8_t *dataPtr8 = reinterpret_cast(region.getPtr()); + for (size_t i = widthInBytes, e = (widthInBytes + 7) / 8; i < e; ++i) + if (dataPtr8[i] != 0xFF) + throw std::runtime_error("DMA write test failed -- write went too far"); } }