Skip to content

Commit

Permalink
[ESI][Runtime] Fix cosim disconnect bug and add logging
Browse files Browse the repository at this point in the history
- Mitigate gRPC memory issue on disconnects. I still don't understand
how to disconnect from a channel properly, so don't ever disconnect the
underlying gRPC channel.
- Add debug MMIO logging.
  • Loading branch information
teqdruid committed Feb 27, 2025
1 parent 65a32af commit 6971f69
Showing 1 changed file with 46 additions and 13 deletions.
59 changes: 46 additions & 13 deletions lib/Dialect/ESI/runtime/cpp/lib/backends/Cosim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,11 @@ class ReadCosimChannelPort
: public ReadChannelPort,
public grpc::ClientReadReactor<esi::cosim::Message> {
public:
ReadCosimChannelPort(ChannelServer::Stub *rpcClient, const ChannelDesc &desc,
ReadCosimChannelPort(AcceleratorConnection &conn,
ChannelServer::Stub *rpcClient, const ChannelDesc &desc,
const Type *type, std::string name)
: ReadChannelPort(type), rpcClient(rpcClient), desc(desc), name(name),
context(nullptr) {}
: ReadChannelPort(type), conn(conn), rpcClient(rpcClient), desc(desc),
name(name), context(nullptr) {}
virtual ~ReadCosimChannelPort() { disconnect(); }

void connectImpl(std::optional<unsigned> bufferSize) override {
Expand All @@ -244,8 +245,10 @@ class ReadCosimChannelPort
assert(desc.name() == name);

// Initiate a stream of messages from the server.
context = std::make_unique<ClientContext>();
rpcClient->async()->ConnectToClientChannel(context.get(), &desc, this);
if (context)
return;
context = new ClientContext();
rpcClient->async()->ConnectToClientChannel(context, &desc, this);
StartCall();
StartRead(&incomingMessage);
}
Expand All @@ -272,21 +275,25 @@ class ReadCosimChannelPort

/// Disconnect this channel from the server.
void disconnect() override {
Logger &logger = conn.getLogger();
logger.debug("cosim_read", "Disconnecting channel " + name);
if (!context)
return;
context->TryCancel();
context.reset();
// Don't delete the context since gRPC still holds a reference to it.
// TODO: figure out how to delete it.
ReadChannelPort::disconnect();
}

protected:
AcceleratorConnection &conn;
ChannelServer::Stub *rpcClient;
/// The channel description as provided by the server.
ChannelDesc desc;
/// The name of the channel from the manifest.
std::string name;

std::unique_ptr<ClientContext> context;
ClientContext *context;
/// Storage location for the incoming message.
esi::cosim::Message incomingMessage;
};
Expand Down Expand Up @@ -335,7 +342,7 @@ class CosimMMIO : public MMIO {
cmdArgPort = std::make_unique<WriteCosimChannelPort>(
rpcClient->stub.get(), cmdArg, cmdType, "__cosim_mmio_read_write.arg");
cmdRespPort = std::make_unique<ReadCosimChannelPort>(
rpcClient->stub.get(), cmdResp, i64Type,
conn, rpcClient->stub.get(), cmdResp, i64Type,
"__cosim_mmio_read_write.result");
auto *bundleType = new BundleType(
"cosimMMIO", {{"arg", BundleType::Direction::To, cmdType},
Expand All @@ -359,10 +366,24 @@ class CosimMMIO : public MMIO {
auto arg = MessageData::from(cmd);
std::future<MessageData> result = cmdMMIO->call(arg);
result.wait();
return *result.get().as<uint64_t>();
uint64_t ret = *result.get().as<uint64_t>();
conn.getLogger().debug(
[addr, ret](std::string &subsystem, std::string &msg,
std::unique_ptr<std::map<std::string, std::any>> &details) {
subsystem = "cosim_mmio";
msg = "MMIO[0x" + toHex(addr) + "] = 0x" + toHex(ret);
});
return ret;
}

void write(uint32_t addr, uint64_t data) override {
conn.getLogger().debug(
[addr,
data](std::string &subsystem, std::string &msg,
std::unique_ptr<std::map<std::string, std::any>> &details) {
subsystem = "cosim_mmio";
msg = "MMIO[0x" + toHex(addr) + "] <- 0x" + toHex(data);
});
MMIOCmd cmd{.data = data, .offset = addr, .write = true};
auto arg = MessageData::from(cmd);
std::future<MessageData> result = cmdMMIO->call(arg);
Expand Down Expand Up @@ -415,6 +436,9 @@ class CosimHostMem : public HostMem {
// We have to locate the channels ourselves since this service might be used
// to retrieve the manifest.

if (writeRespPort)
return;

// TODO: The types here are WRONG. They need to be wrapped in Channels! Fix
// this in a subsequent PR.

Expand All @@ -440,7 +464,7 @@ class CosimHostMem : public HostMem {
rpcClient->stub.get(), readResp, readRespType,
"__cosim_hostmem_read_resp.data");
readReqPort = std::make_unique<ReadCosimChannelPort>(
rpcClient->stub.get(), readArg, readReqType,
conn, rpcClient->stub.get(), readArg, readReqType,
"__cosim_hostmem_read_req.data");
readReqPort->connect(
[this](const MessageData &req) { return serviceRead(req); });
Expand All @@ -464,7 +488,7 @@ class CosimHostMem : public HostMem {
rpcClient->stub.get(), writeResp, writeRespType,
"__cosim_hostmem_write.result");
writeReqPort = std::make_unique<ReadCosimChannelPort>(
rpcClient->stub.get(), writeArg, writeReqType,
conn, rpcClient->stub.get(), writeArg, writeReqType,
"__cosim_hostmem_write.arg");
auto *bundleType = new BundleType(
"cosimHostMem",
Expand Down Expand Up @@ -528,6 +552,7 @@ class CosimHostMem : public HostMem {
struct CosimHostMemRegion : public HostMemRegion {
/// Allocate a host memory region of `size` bytes with malloc() and fill it
/// with 0xFF bytes (presumably so reads of never-written memory are easy to
/// spot -- confirm with the author).
///
/// @param size Number of bytes requested; also recorded as the region size.
CosimHostMemRegion(std::size_t size) {
  ptr = malloc(size);
  // malloc() returns null on failure (and is allowed to for a zero-byte
  // request). Handing a null pointer to memset() is undefined behavior, so
  // only fill a buffer we actually received.
  if (ptr != nullptr)
    memset(ptr, 0xFF, size);
  this->size = size;
}
/// Release the buffer that the constructor obtained from malloc().
virtual ~CosimHostMemRegion() {
  // free() accepts a null pointer, so no guard is needed here.
  free(ptr);
}
Expand All @@ -541,7 +566,15 @@ class CosimHostMem : public HostMem {

virtual std::unique_ptr<HostMemRegion>
allocate(std::size_t size, HostMem::Options opts) const override {
return std::unique_ptr<HostMemRegion>(new CosimHostMemRegion(size));
auto ret = std::unique_ptr<HostMemRegion>(new CosimHostMemRegion(size));
acc.getLogger().debug(
[&](std::string &subsystem, std::string &msg,
std::unique_ptr<std::map<std::string, std::any>> &details) {
subsystem = "HostMem";
msg = "Allocated host memory region at 0x" + toHex(ret->getPtr()) +
" of size " + std::to_string(size);
});
return ret;
}
virtual bool mapMemory(void *ptr, std::size_t size,
HostMem::Options opts) const override {
Expand Down Expand Up @@ -634,7 +667,7 @@ CosimEngine::createPort(AppIDPath idPath, const std::string &channelName,
conn.rpcClient->stub.get(), chDesc, type, fullChannelName);
else
port = std::make_unique<ReadCosimChannelPort>(
conn.rpcClient->stub.get(), chDesc, type, fullChannelName);
conn, conn.rpcClient->stub.get(), chDesc, type, fullChannelName);
return port;
}

Expand Down

0 comments on commit 6971f69

Please sign in to comment.