Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ESI][BSP] Adding byte enables to cosim hostmem #8138

Merged
merged 1 commit into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions frontends/PyCDE/src/pycde/bsp/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from math import ceil

from ..common import Clock, Input, InputChannel, Output, OutputChannel, Reset
from ..constructs import (AssignableSignal, ControlReg, Counter, NamedWire, Reg,
Wire)
from ..constructs import (AssignableSignal, ControlReg, Counter, Mux, NamedWire,
Reg, Wire)
from .. import esi
from ..module import Module, generator, modparams
from ..signals import BitsSignal, BundleSignal, ChannelSignal
Expand Down Expand Up @@ -458,6 +458,8 @@ def TaggedWriteGearbox(input_bitwidth: int,

if output_bitwidth % 8 != 0:
raise ValueError("Output bitwidth must be a multiple of 8.")
if input_bitwidth % 8 != 0:
raise ValueError("Input bitwidth must be a multiple of 8.")

class TaggedWriteGearboxImpl(Module):
clk = Clock()
Expand All @@ -473,15 +475,20 @@ class TaggedWriteGearboxImpl(Module):
("address", UInt(64)),
("tag", esi.HostMem.TagType),
("data", Bits(output_bitwidth)),
("valid_bytes", Bits(8)),
]))

num_chunks = ceil(input_bitwidth / output_bitwidth)

@generator
def build(ports):
upstream_ready = Wire(Bits(1))
ready_for_client = Wire(Bits(1))
client_tag_and_data, client_valid = ports.in_.unwrap(ready_for_client)
client_data = client_tag_and_data.data
client_xact = ready_for_client & client_valid
input_bitwidth_bytes = input_bitwidth // 8
output_bitwidth_bytes = output_bitwidth // 8

# Determine if gearboxing is necessary and whether it needs to be
# gearboxed up or just sliced down.
Expand All @@ -491,19 +498,23 @@ def build(ports):
ready_for_client.assign(upstream_ready)
tag = client_tag_and_data.tag
address = client_tag_and_data.address
valid_bytes = Bits(8)(input_bitwidth_bytes)
elif output_bitwidth > input_bitwidth:
upstream_data_bits = client_data.as_bits(output_bitwidth)
upstream_valid = client_valid
ready_for_client.assign(upstream_ready)
tag = client_tag_and_data.tag
address = client_tag_and_data.address
valid_bytes = Bits(8)(input_bitwidth_bytes)
else:
# Create registers equal to the number of upstream transactions needed
# to complete the transmission.
num_chunks = ceil(input_bitwidth / output_bitwidth)
num_chunks = TaggedWriteGearboxImpl.num_chunks
num_chunks_idx_bitwidth = clog2(num_chunks)
padding = Bits(output_bitwidth - (input_bitwidth % output_bitwidth))(0)
client_data_padded = BitsSignal.concat([padding, client_data])
padding_numbits = output_bitwidth - (input_bitwidth % output_bitwidth)
assert padding_numbits % 8 == 0, "Padding must be a multiple of 8."
client_data_padded = BitsSignal.concat(
[Bits(padding_numbits)(0), client_data])
chunks = [
client_data_padded[i * output_bitwidth:(i + 1) * output_bitwidth]
for i in range(num_chunks)
Expand Down Expand Up @@ -537,12 +548,16 @@ def build(ports):
name="address_reg")
address = (addr_reg + counter_bytes).as_uint(64)
tag = tag_reg
valid_bytes = Mux(counter.out == (num_chunks - 1),
Bits(8)(output_bitwidth_bytes),
Bits(8)(padding_numbits // 8))

upstream_channel, upstrm_ready_sig = TaggedWriteGearboxImpl.out.type.wrap(
{
"address": address,
"tag": tag,
"data": upstream_data_bits,
"valid_bytes": valid_bytes
}, upstream_valid)
upstream_ready.assign(upstrm_ready_sig)
ports.out = upstream_channel
Expand Down Expand Up @@ -584,7 +599,8 @@ def build(ports):
{
"address": 0,
"tag": 0,
"data": 0
"data": 0,
"valid_bytes": 0,
}, 0)
write_bundle, _ = hostmem_module.write.type.pack(req=req)
ports.upstream = write_bundle
Expand Down Expand Up @@ -633,6 +649,7 @@ def build(ports):
"address": m.address,
"tag": idx,
"data": m.data,
"valid_bytes": m.valid_bytes
})))
# Set the port for the client request.
setattr(ports, HostMemWriteProcessorImpl.reqPortMap[req], bundle_sig)
Expand Down Expand Up @@ -670,10 +687,14 @@ class ChannelHostMemImpl(esi.ServiceImplementation):
("data", Bits(read_width)),
])),
]))

if write_width % 8 != 0:
raise ValueError("Write width must be a multiple of 8.")
UpstreamWriteReq = StructType([
("address", UInt(64)),
("tag", UInt(8)),
("data", Bits(write_width)),
("valid_bytes", Bits(8)),
])
write = Output(
Bundle([
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/ESI/runtime/cpp/include/esi/Common.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ std::ostream &operator<<(std::ostream &, const esi::AppID &);
//===----------------------------------------------------------------------===//

namespace esi {
std::string toHex(uint32_t val);
std::string toHex(uint64_t val);
} // namespace esi

#endif // ESI_COMMON_H
2 changes: 1 addition & 1 deletion lib/Dialect/ESI/runtime/cpp/lib/Common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ std::string MessageData::toHex() const {
return ss.str();
}

std::string esi::toHex(uint32_t val) {
std::string esi::toHex(uint64_t val) {
std::ostringstream ss;
ss << std::hex << val;
return ss.str();
Expand Down
10 changes: 7 additions & 3 deletions lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,7 @@ struct HostMemReadResp {
};

struct HostMemWriteReq {
uint8_t valid_bytes;
uint64_t data;
uint8_t tag;
uint64_t address;
Expand Down Expand Up @@ -540,10 +541,13 @@ class CosimHostMem : public HostMem {
std::unique_ptr<std::map<std::string, std::any>> &details) {
subsystem = "HostMem";
msg = "Write request: addr=0x" + toHex(req->address) + " data=0x" +
toHex(req->data) + " tag=" + std::to_string(req->tag);
toHex(req->data) +
" valid_bytes=" + std::to_string(req->valid_bytes) +
" tag=" + std::to_string(req->tag);
});
uint64_t *dataPtr = reinterpret_cast<uint64_t *>(req->address);
*dataPtr = req->data;
uint8_t *dataPtr = reinterpret_cast<uint8_t *>(req->address);
for (uint8_t i = 0; i < req->valid_bytes; ++i)
dataPtr[i] = (req->data >> (i * 8)) & 0xFF;
HostMemWriteResp resp = req->tag;
return MessageData::from(resp);
}
Expand Down
20 changes: 16 additions & 4 deletions lib/Dialect/ESI/runtime/cpp/tools/esitester.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,16 @@ void dmaTest(Accelerator *acc, esi::services::HostMem::HostMemRegion &region,
// Initiate a test write.
// TODO: remove the width == 96 once multiplexing support is added.
if (write) {
auto check = [&]() {
assert(width % 8 == 0);
auto check = [&](bool print) {
region.flush();
for (size_t i = 0, e = (width + 63) / 64; i < e; ++i)
for (size_t i = 0, e = (width + 63) / 64; i < e; ++i) {
if (print)
std::cout << "dataPtr[" << i << "] = 0x" << esi::toHex(dataPtr[i])
<< std::endl;
if (dataPtr[i] == 0xFFFFFFFFFFFFFFFFull)
return false;
}
return true;
};

Expand All @@ -179,12 +184,19 @@ void dmaTest(Accelerator *acc, esi::services::HostMem::HostMemRegion &region,
writeMem->write(0, reinterpret_cast<uint64_t>(devicePtr));
// Wait for the accelerator to write. Timeout and fail after 10ms.
for (int i = 0; i < 100; ++i) {
if (check())
if (check(false))
break;
std::this_thread::sleep_for(std::chrono::microseconds(100));
}
if (!check())
if (!check(true))
throw std::runtime_error("DMA write test failed");

// Check that the accelerator didn't write too far.
size_t widthInBytes = width / 8;
uint8_t *dataPtr8 = reinterpret_cast<uint8_t *>(region.getPtr());
for (size_t i = widthInBytes, e = (widthInBytes + 7) / 8; i < e; ++i)
if (dataPtr8[i] != 0xFF)
throw std::runtime_error("DMA write test failed -- write went too far");
}
}

Expand Down