Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PyCDE][ESI] Manifest: add record about client #7299

Merged
merged 1 commit into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 38 additions & 15 deletions frontends/PyCDE/src/pycde/esi.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .common import (AppID, Input, Output, _PyProxy, PortError)
from .module import Generator, Module, ModuleLikeBuilderBase, PortProxyBase
from .signals import BundleSignal, ChannelSignal, Signal, _FromCirctValue
from .support import get_user_loc
from .system import System
from .types import (Bits, Bundle, BundledChannel, ChannelDirection, Type, types,
_FromCirctType)
Expand Down Expand Up @@ -128,11 +129,35 @@ class _OutputBundleSetter:
have implemented for this request."""

def __init__(self, req: raw_esi.ServiceImplementConnReqOp,
rec: raw_esi.ServiceImplRecordOp,
old_value_to_replace: ir.OpResult):
self.req = req
self.rec = rec
self.type: Bundle = _FromCirctType(req.toClient.type)
self.client_name = req.relativeAppIDPath
self.port = hw.InnerRefAttr(req.servicePort).name.value
self._bundle_to_replace: Optional[ir.OpResult] = old_value_to_replace

def add_record(self, details: Dict[str, str]):
"""Add a record to the manifest for this client request. Generally used to
give the runtime necessary information about how to connect to the client
through the generated service. For instance, offsets into an MMIO space."""

ir_details: Dict[str, ir.StringAttr] = {}
for k, v in details.items():
ir_details[k] = ir.StringAttr.get(str(v))
with get_user_loc(), ir.InsertionPoint.at_block_begin(
self.rec.reqDetails.blocks[0]):
raw_esi.ServiceImplClientRecordOp(
self.req.relativeAppIDPath,
self.req.servicePort,
ir.TypeAttr.get(self.req.toClient.type),
ir_details,
)

@property
def client_name(self):
return self.req.relativeAppIDPath

def assign(self, new_value: ChannelSignal):
"""Assign the generated channel to this request."""
if self._bundle_to_replace is None:
Expand All @@ -150,8 +175,10 @@ class _ServiceGeneratorBundles:
for connecting up."""

def __init__(self, mod: ModuleLikeBuilderBase,
req: raw_esi.ServiceImplementReqOp):
req: raw_esi.ServiceImplementReqOp,
rec: raw_esi.ServiceImplRecordOp):
self._req = req
self._rec = rec
portReqsBlock = req.portReqs.blocks[0]

# Find the output channel requests and store the settable proxies.
Expand All @@ -161,25 +188,19 @@ def __init__(self, mod: ModuleLikeBuilderBase,
if isinstance(req, raw_esi.ServiceImplementConnReqOp)
]
self._output_reqs = [
_OutputBundleSetter(req, self._req.results[num_output_ports + idx])
_OutputBundleSetter(req, rec, self._req.results[num_output_ports + idx])
for idx, req in enumerate(to_client_reqs)
]
assert len(self._output_reqs) == len(req.results) - num_output_ports

@property
def reqs(self) -> List[NamedChannelValue]:
"""Get the list of incoming channels from the 'to server' connection
requests."""
return self._input_reqs

@property
def to_client_reqs(self) -> List[_OutputBundleSetter]:
return self._output_reqs

def check_unconnected_outputs(self):
for req in self._output_reqs:
if req._bundle_to_replace is not None:
name_str = ".".join(req.client_name)
name_str = str(req.client_name)
raise ValueError(f"{name_str} has not been connected.")


Expand All @@ -206,8 +227,8 @@ def instantiate(self, impl, inputs: Dict[str, Signal], appid: AppID):
impl_opts=opts,
loc=self.loc)

def generate_svc_impl(self,
serviceReq: raw_esi.ServiceImplementReqOp) -> bool:
def generate_svc_impl(self, serviceReq: raw_esi.ServiceImplementReqOp,
record_op: raw_esi.ServiceImplRecordOp) -> bool:
""""Generate the service inline and replace the `ServiceInstanceOp` which is
being implemented."""

Expand All @@ -217,7 +238,7 @@ def generate_svc_impl(self,
with self.GeneratorCtxt(self, ports, serviceReq, generator.loc):

# Run the generator.
bundles = _ServiceGeneratorBundles(self, serviceReq)
bundles = _ServiceGeneratorBundles(self, serviceReq, record_op)
rc = generator.gen_func(ports, bundles=bundles)
if rc is None:
rc = True
Expand Down Expand Up @@ -292,7 +313,8 @@ def register(self,
self._registry[name_attr] = (service_implementation, System.current())
return ir.DictAttr.get({"name": name_attr})

def _implement_service(self, req: ir.Operation):
def _implement_service(self, req: ir.Operation, decl: ir.Operation,
rec: ir.Operation):
"""This is the callback which the ESI connect-services pass calls. Dispatch
to the op-specified generator."""
assert isinstance(req.opview, raw_esi.ServiceImplementReqOp)
Expand All @@ -302,7 +324,8 @@ def _implement_service(self, req: ir.Operation):
return False
(impl, sys) = self._registry[impl_name]
with sys:
ret = impl._builder.generate_svc_impl(serviceReq=req.opview)
ret = impl._builder.generate_svc_impl(serviceReq=req.opview,
record_op=rec.opview)
# The service implementation generator could have instantiated new modules,
# so we need to generate them. Don't run the appID indexer since during a
# pass, the IR can be invalid and the indexers assumes it is valid.
Expand Down
1 change: 1 addition & 0 deletions frontends/PyCDE/test/test_esi_servicegens.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def generate(self, bundles: esi._ServiceGeneratorBundles):
assert len(
bundles.to_client_reqs) == 1, "Only one connection request supported"
bundle = bundles.to_client_reqs[0]
bundle.add_record({"foo": 5})
to_req_types = {}
for bundled_chan in bundle.type.channels:
if bundled_chan.direction == ChannelDirection.TO:
Expand Down
3 changes: 2 additions & 1 deletion include/circt-c/Dialect/ESI.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ circtESIBundleTypeGetChannel(MlirType bundle, size_t idx);
//===----------------------------------------------------------------------===//

typedef MlirLogicalResult (*CirctESIServiceGeneratorFunc)(
MlirOperation serviceImplementReqOp, MlirOperation declOp, void *userData);
MlirOperation serviceImplementReqOp, MlirOperation declOp,
MlirOperation recordOp, void *userData);
MLIR_CAPI_EXPORTED void circtESIRegisterGlobalServiceGenerator(
MlirStringRef impl_type, CirctESIServiceGeneratorFunc, void *userData);

Expand Down
20 changes: 15 additions & 5 deletions integration_test/Bindings/Python/dialects/esi.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,24 @@
assert (bundle_type.resettable)
print()

# CHECK-LABEL: === testGen called with ops:
# CHECK-NEXT: [[R0:%.+]]:2 = esi.service.impl_req #esi.appid<"mstop"> svc @HostComms impl as "test"(%clk) : (i1) -> (i8, !esi.bundle<[!esi.channel<i8> to "recv"]>) {
# CHECK-NEXT: [[R2:%.+]] = esi.service.impl_req.req <@HostComms::@Recv>([#esi.appid<"loopback_tohw">]) : !esi.bundle<[!esi.channel<i8> to "recv"]>
# CHECK-NEXT: }
# CHECK-NEXT: esi.service.decl @HostComms {
# CHECK-NEXT: esi.service.port @Recv : !esi.bundle<[!esi.channel<i8> to "recv"]>
# CHECK-NEXT: }
# CHECK-NEXT: esi.manifest.service_impl #esi.appid<"mstop"> svc @HostComms by "test" with {}

# CHECK-LABEL: === testGen called with op:
# CHECK: [[R0:%.+]]:2 = esi.service.impl_req #esi.appid<"mstop"> svc @HostComms impl as "test"(%clk) : (i1) -> (i8, !esi.bundle<[!esi.channel<i8> to "recv"]>) {
# CHECK: [[R2:%.+]] = esi.service.impl_req.req <@HostComms::@Recv>([#esi.appid<"loopback_tohw">]) : !esi.bundle<[!esi.channel<i8> to "recv"]>
def testGen(reqOp: esi.ServiceImplementReqOp) -> bool:
print("=== testGen called with op:")

def testGen(reqOp: Operation, decl_op: Operation, rec_op: Operation) -> bool:
print("=== testGen called with ops:")
reqOp.print()
print()
decl_op.print()
print()
rec_op.print()
print()
return True


Expand Down
5 changes: 3 additions & 2 deletions lib/Bindings/Python/ESIModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,12 @@ using namespace circt::esi;
// pointers since we also need to allocate memory for the string.
llvm::DenseMap<std::string *, PyObject *> serviceGenFuncLookup;
static MlirLogicalResult serviceGenFunc(MlirOperation reqOp,
MlirOperation declOp, void *userData) {
MlirOperation declOp,
MlirOperation recOp, void *userData) {
std::string *name = static_cast<std::string *>(userData);
py::handle genFunc(serviceGenFuncLookup[name]);
py::gil_scoped_acquire();
py::object rc = genFunc(reqOp);
py::object rc = genFunc(reqOp, declOp, recOp);
return rc.cast<bool>() ? mlirLogicalResultSuccess()
: mlirLogicalResultFailure();
}
Expand Down
9 changes: 5 additions & 4 deletions lib/CAPI/Dialect/ESI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,10 +86,11 @@ void circtESIRegisterGlobalServiceGenerator(
MlirStringRef impl_type, CirctESIServiceGeneratorFunc genFunc,
void *userData) {
ServiceGeneratorDispatcher::globalDispatcher().registerGenerator(
unwrap(impl_type),
[genFunc, userData](ServiceImplementReqOp req,
ServiceDeclOpInterface decl, ServiceImplRecordOp) {
return unwrap(genFunc(wrap(req), wrap(decl.getOperation()), userData));
unwrap(impl_type), [genFunc, userData](ServiceImplementReqOp req,
ServiceDeclOpInterface decl,
ServiceImplRecordOp record) {
return unwrap(genFunc(wrap(req), wrap(decl.getOperation()),
wrap(record.getOperation()), userData));
});
}
//===----------------------------------------------------------------------===//
Expand Down
Loading