Skip to content

Commit

Permalink
[FIRRTL] LowerXMR: process all modules
Browse files Browse the repository at this point in the history
This changes LowerXMR to process all modules instead of just those that
are reachable from the top level module.
  • Loading branch information
youngar committed Feb 1, 2025
1 parent 671dec5 commit 0a437b9
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 55 deletions.
114 changes: 59 additions & 55 deletions lib/Dialect/FIRRTL/Transforms/LowerXMR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,69 +300,73 @@ class LowerXMRPass : public circt::firrtl::impl::LowerXMRBase<LowerXMRPass> {
SmallVector<FModuleOp> publicModules;

// Traverse the modules in post order.
for (auto node : llvm::post_order(&instanceGraph)) {
auto module = dyn_cast<FModuleOp>(*node->getModule());
if (!module)
continue;
LLVM_DEBUG(llvm::dbgs()
<< "Traversing module:" << module.getModuleNameAttr() << "\n");

moduleStates.insert({module, ModuleState(module)});
DenseSet<InstanceGraphNode *> visited;
for (auto *root : instanceGraph) {
for (auto *node : llvm::post_order_ext(root, visited)) {
auto module = dyn_cast<FModuleOp>(*node->getModule());
if (!module)
continue;
LLVM_DEBUG(llvm::dbgs() << "Traversing module:"
<< module.getModuleNameAttr() << "\n");

if (module.isPublic())
publicModules.push_back(module);
moduleStates.insert({module, ModuleState(module)});

auto result = module.walk([&](Operation *op) {
if (transferFunc(op).failed())
return WalkResult::interrupt();
return WalkResult::advance();
});
if (module.isPublic())
publicModules.push_back(module);

if (result.wasInterrupted())
return signalPassFailure();
auto result = module.walk([&](Operation *op) {
if (transferFunc(op).failed())
return WalkResult::interrupt();
return WalkResult::advance();
});

// Clear any enabled layers.
module.setLayersAttr(ArrayAttr::get(module.getContext(), {}));

// Since we walk operations pre-order and not along dataflow edges,
// ref.sub may not be resolvable when we encounter them (they're not just
// unification). This can happen when refs go through an output port or
// input instance result and back into the design. Handle these by walking
// them, resolving what we can, until all are handled or nothing can be
// resolved.
while (!indexingOps.empty()) {
// Grab the set of unresolved ref.sub's.
decltype(indexingOps) worklist;
worklist.swap(indexingOps);

for (auto op : worklist) {
auto inputEntry =
getRemoteRefSend(op.getInput(), /*errorIfNotFound=*/false);
// If we can't resolve, add back and move on.
if (!inputEntry)
indexingOps.push_back(op);
else
addReachingSendsEntry(op.getResult(), op.getOperation(),
inputEntry);
}
// If nothing was resolved, give up.
if (worklist.size() == indexingOps.size()) {
auto op = worklist.front();
getRemoteRefSend(op.getInput());
op.emitError(
"indexing through probe of unknown origin (input probe?)")
.attachNote(op.getInput().getLoc())
.append("indexing through this reference");
if (result.wasInterrupted())
return signalPassFailure();
}
}

// Record all the RefType ports to be removed later.
size_t numPorts = module.getNumPorts();
for (size_t portNum = 0; portNum < numPorts; ++portNum)
if (isa<RefType>(module.getPortType(portNum))) {
setPortToRemove(module, portNum, numPorts);
// Clear any enabled layers.
module.setLayersAttr(ArrayAttr::get(module.getContext(), {}));

// Since we walk operations pre-order and not along dataflow edges,
// ref.sub may not be resolvable when we encounter them (they're not
// just unification). This can happen when refs go through an output
// port or input instance result and back into the design. Handle these
// by walking them, resolving what we can, until all are handled or
// nothing can be resolved.
while (!indexingOps.empty()) {
// Grab the set of unresolved ref.sub's.
decltype(indexingOps) worklist;
worklist.swap(indexingOps);

for (auto op : worklist) {
auto inputEntry =
getRemoteRefSend(op.getInput(), /*errorIfNotFound=*/false);
// If we can't resolve, add back and move on.
if (!inputEntry)
indexingOps.push_back(op);
else
addReachingSendsEntry(op.getResult(), op.getOperation(),
inputEntry);
}
// If nothing was resolved, give up.
if (worklist.size() == indexingOps.size()) {
auto op = worklist.front();
getRemoteRefSend(op.getInput());
op.emitError(
"indexing through probe of unknown origin (input probe?)")
.attachNote(op.getInput().getLoc())
.append("indexing through this reference");
return signalPassFailure();
}
}

// Record all the RefType ports to be removed later.
size_t numPorts = module.getNumPorts();
for (size_t portNum = 0; portNum < numPorts; ++portNum)
if (isa<RefType>(module.getPortType(portNum))) {
setPortToRemove(module, portNum, numPorts);
}
}
}

LLVM_DEBUG({
Expand Down
23 changes: 23 additions & 0 deletions test/Dialect/FIRRTL/lowerXMR.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -796,3 +796,26 @@ firrtl.circuit "Foo" {
}
}
}

// -----
// Test that all modules are reached and updated.

// CHECK-LABEL: firrtl.circuit "PF"
firrtl.circuit "PF" {
// CHECK: @Child()
firrtl.module @Child(out %p: !firrtl.probe<uint<1>>) {
%c1_ui1 = firrtl.constant 1 : !firrtl.uint<1>
%0 = firrtl.ref.send %c1_ui1 : !firrtl.uint<1>
firrtl.ref.define %p, %0 : !firrtl.probe<uint<1>>
}
// CHECK: @PF()
firrtl.module @PF(out %p: !firrtl.probe<uint<1>>) {
%c_p = firrtl.instance c @Child(out p: !firrtl.probe<uint<1>>)
firrtl.ref.define %p, %c_p : !firrtl.probe<uint<1>>
}
// CHECK: @Other()
firrtl.module @Other(out %p: !firrtl.probe<uint<1>>) {
%c_p = firrtl.instance c @Child(out p: !firrtl.probe<uint<1>>)
firrtl.ref.define %p, %c_p : !firrtl.probe<uint<1>>
}
}

0 comments on commit 0a437b9

Please sign in to comment.