Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Arc][Sim] Lower Sim DPI func to func.func and support dpi call in Arc #7386

Merged
merged 7 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/circt/Dialect/Sim/SimPasses.td
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,9 @@ def ProceduralizeSim : Pass<"sim-proceduralize", "hw::HWModuleOp"> {
let dependentDialects = ["circt::hw::HWDialect, circt::seq::SeqDialect, mlir::scf::SCFDialect"];
}

def LowerDPIFunc : Pass<"sim-lower-dpi-func", "mlir::ModuleOp"> {
let summary = "Lower sim.dpi.func into func.func for the simulation flow";
let dependentDialects = ["mlir::func::FuncDialect", "mlir::LLVM::LLVMDialect"];
}

#endif // CIRCT_DIALECT_SIM_SEQPASSES
39 changes: 39 additions & 0 deletions integration_test/arcilator/JIT/dpi.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// RUN: arcilator %s --run --jit-entry=main | FileCheck %s
// REQUIRES: arcilator-jit

// CHECK: c = 0
// CHECK-NEXT: c = 5
sim.func.dpi @dpi(in %a : i32, in %b : i32, out c : i32) attributes {verilogName = "adder_func"}
func.func @adder_func(%arg0: i32, %arg1: i32, %arg2: !llvm.ptr) {
%0 = arith.addi %arg0, %arg1 : i32
llvm.store %0, %arg2 : i32, !llvm.ptr
return
}
hw.module @adder(in %clock : i1, in %a : i32, in %b : i32, out c : i32) {
%seq_clk = seq.to_clock %clock

%0 = sim.func.dpi.call @dpi(%a, %b) clock %seq_clk : (i32, i32) -> i32
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just out of curiosity: since the sim.func.dpi.call function seems like it accepts both sim.func.dpi and func.func callables, would this code also work with a second sim.func.dpi.call that calls the @adder_func directly?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Directly calling @adder_func doesn't work since it's output value is passed by reference in argument. If @adder_func was a normal dataflow-ish func.func such as func.func @adder_func_ok(%arg0: i32, %arg1: i32) -> i32, we can call it directly.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ooooh you're right, sorry, I overlooked that difference in the operand layout. Makes sense!

hw.output %0 : i32
}
func.func @main() {
%c2_i32 = arith.constant 2 : i32
%c3_i32 = arith.constant 3 : i32
%one = arith.constant 1 : i1
%zero = arith.constant 0 : i1
arc.sim.instantiate @adder as %arg0 {
arc.sim.set_input %arg0, "a" = %c2_i32 : i32, !arc.sim.instance<@adder>
arc.sim.set_input %arg0, "b" = %c3_i32 : i32, !arc.sim.instance<@adder>
arc.sim.set_input %arg0, "clock" = %one : i1, !arc.sim.instance<@adder>

arc.sim.step %arg0 : !arc.sim.instance<@adder>
arc.sim.set_input %arg0, "clock" = %zero : i1, !arc.sim.instance<@adder>
%0 = arc.sim.get_port %arg0, "c" : i32, !arc.sim.instance<@adder>
arc.sim.emit "c", %0 : i32

arc.sim.step %arg0 : !arc.sim.instance<@adder>
arc.sim.set_input %arg0, "clock" = %one : i1, !arc.sim.instance<@adder>
%2 = arc.sim.get_port %arg0, "c" : i32, !arc.sim.instance<@adder>
arc.sim.emit "c", %2 : i32
}
return
}
1 change: 1 addition & 0 deletions lib/Conversion/ConvertToArcs/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ add_circt_conversion_library(CIRCTConvertToArcs
CIRCTArc
CIRCTHW
CIRCTSeq
CIRCTSim
MLIRTransforms
)
3 changes: 2 additions & 1 deletion lib/Conversion/ConvertToArcs/ConvertToArcs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "circt/Dialect/Arc/ArcOps.h"
#include "circt/Dialect/HW/HWOps.h"
#include "circt/Dialect/Seq/SeqOps.h"
#include "circt/Dialect/Sim/SimOps.h"
#include "circt/Support/Namespace.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
Expand All @@ -25,7 +26,7 @@ using llvm::MapVector;
static bool isArcBreakingOp(Operation *op) {
return op->hasTrait<OpTrait::ConstantLike>() ||
isa<hw::InstanceOp, seq::CompRegOp, MemoryOp, ClockedOpInterface,
seq::ClockGateOp>(op) ||
seq::ClockGateOp, sim::DPICallOp>(op) ||
op->getNumResults() > 1;
}

Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Arc/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ add_circt_dialect_library(CIRCTArcTransforms
CIRCTOM
CIRCTSV
CIRCTSeq
CIRCTSim
CIRCTSupport
MLIRFuncDialect
MLIRLLVMDialect
Expand Down
93 changes: 60 additions & 33 deletions lib/Dialect/Arc/Transforms/LowerState.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "circt/Dialect/Comb/CombOps.h"
#include "circt/Dialect/HW/HWOps.h"
#include "circt/Dialect/Seq/SeqOps.h"
#include "circt/Dialect/Sim/SimOps.h"
#include "circt/Support/BackedgeBuilder.h"
#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
Expand Down Expand Up @@ -117,7 +118,12 @@ struct ModuleLowering {
LogicalResult lowerPrimaryInputs();
LogicalResult lowerPrimaryOutputs();
LogicalResult lowerStates();
template <typename CallTy>
LogicalResult lowerStateLike(Operation *op, Value clock, Value enable,
Value reset, ArrayRef<Value> inputs,
FlatSymbolRefAttr callee);
LogicalResult lowerState(StateOp stateOp);
LogicalResult lowerState(sim::DPICallOp dpiCallOp);
LogicalResult lowerState(MemoryOp memOp);
LogicalResult lowerState(MemoryWritePortOp memWriteOp);
LogicalResult lowerState(TapOp tapOp);
Expand All @@ -139,7 +145,7 @@ static bool shouldMaterialize(Operation *op) {
return !isa<MemoryOp, AllocStateOp, AllocMemoryOp, AllocStorageOp,
ClockTreeOp, PassThroughOp, RootInputOp, RootOutputOp,
StateWriteOp, MemoryWritePortOp, igraph::InstanceOpInterface,
StateOp>(op);
StateOp, sim::DPICallOp>(op);
}

static bool shouldMaterialize(Value value) {
Expand Down Expand Up @@ -390,53 +396,48 @@ LogicalResult ModuleLowering::lowerPrimaryOutputs() {
LogicalResult ModuleLowering::lowerStates() {
SmallVector<Operation *> opsToLower;
for (auto &op : *moduleOp.getBodyBlock())
if (isa<StateOp, MemoryOp, MemoryWritePortOp, TapOp>(&op))
if (isa<StateOp, MemoryOp, MemoryWritePortOp, TapOp, sim::DPICallOp>(&op))
opsToLower.push_back(&op);

for (auto *op : opsToLower) {
LLVM_DEBUG(llvm::dbgs() << "- Lowering " << *op << "\n");
auto result = TypeSwitch<Operation *, LogicalResult>(op)
.Case<StateOp, MemoryOp, MemoryWritePortOp, TapOp>(
[&](auto op) { return lowerState(op); })
.Default(success());
auto result =
TypeSwitch<Operation *, LogicalResult>(op)
.Case<StateOp, MemoryOp, MemoryWritePortOp, TapOp, sim::DPICallOp>(
[&](auto op) { return lowerState(op); })
.Default(success());
if (failed(result))
return failure();
}
return success();
}

LogicalResult ModuleLowering::lowerState(StateOp stateOp) {
// We don't support arcs beyond latency 1 yet. These should be easy to add in
// the future though.
if (stateOp.getLatency() > 1)
return stateOp.emitError("state with latency > 1 not supported");

// Grab all operands from the state op and make it drop all its references.
// This allows `materializeValue` to move an operation if this state was the
// last user.
auto stateClock = stateOp.getClock();
auto stateEnable = stateOp.getEnable();
auto stateReset = stateOp.getReset();
auto stateInputs = SmallVector<Value>(stateOp.getInputs());
template <typename CallOpTy>
LogicalResult ModuleLowering::lowerStateLike(
Operation *stateOp, Value stateClock, Value stateEnable, Value stateReset,
ArrayRef<Value> stateInputs, FlatSymbolRefAttr callee) {
// Grab all operands from the state op at the callsite and make it drop all
// its references. This allows `materializeValue` to move an operation if this
// state was the last user.

// Get the clock tree and enable condition for this state's clock. If this arc
// carries an explicit enable condition, fold that into the enable provided by
// the clock gates in the arc's clock tree.
auto info = getOrCreateClockLowering(stateClock);
info.enable = info.clock.getOrCreateAnd(
info.enable, info.clock.materializeValue(stateEnable), stateOp.getLoc());
info.enable, info.clock.materializeValue(stateEnable), stateOp->getLoc());

// Allocate the necessary state within the model.
SmallVector<Value> allocatedStates;
for (unsigned stateIdx = 0; stateIdx < stateOp.getNumResults(); ++stateIdx) {
auto type = stateOp.getResult(stateIdx).getType();
for (unsigned stateIdx = 0; stateIdx < stateOp->getNumResults(); ++stateIdx) {
auto type = stateOp->getResult(stateIdx).getType();
auto intType = dyn_cast<IntegerType>(type);
if (!intType)
return stateOp.emitOpError("result ")
return stateOp->emitOpError("result ")
<< stateIdx << " has non-integer type " << type
<< "; only integer types are supported";
auto stateType = StateType::get(intType);
auto state = stateBuilder.create<AllocStateOp>(stateOp.getLoc(), stateType,
auto state = stateBuilder.create<AllocStateOp>(stateOp->getLoc(), stateType,
storageArg);
if (auto names = stateOp->getAttrOfType<ArrayAttr>("names"))
state->setAttr("name", names[stateIdx]);
Expand All @@ -455,18 +456,18 @@ LogicalResult ModuleLowering::lowerState(StateOp stateOp) {
OpBuilder nonResetBuilder = info.clock.builder;
if (stateReset) {
auto materializedReset = info.clock.materializeValue(stateReset);
auto ifOp = info.clock.builder.create<scf::IfOp>(stateOp.getLoc(),
auto ifOp = info.clock.builder.create<scf::IfOp>(stateOp->getLoc(),
materializedReset, true);

for (auto [alloc, resTy] :
llvm::zip(allocatedStates, stateOp.getResultTypes())) {
llvm::zip(allocatedStates, stateOp->getResultTypes())) {
if (!isa<IntegerType>(resTy))
stateOp->emitOpError("Non-integer result not supported yet!");

auto thenBuilder = ifOp.getThenBodyBuilder();
Value constZero =
thenBuilder.create<hw::ConstantOp>(stateOp.getLoc(), resTy, 0);
thenBuilder.create<StateWriteOp>(stateOp.getLoc(), alloc, constZero,
thenBuilder.create<hw::ConstantOp>(stateOp->getLoc(), resTy, 0);
thenBuilder.create<StateWriteOp>(stateOp->getLoc(), alloc, constZero,
Value());
}

Expand All @@ -475,24 +476,50 @@ LogicalResult ModuleLowering::lowerState(StateOp stateOp) {

stateOp->dropAllReferences();

auto newStateOp = nonResetBuilder.create<CallOp>(
stateOp.getLoc(), stateOp.getResultTypes(), stateOp.getArcAttr(),
auto newStateOp = nonResetBuilder.create<CallOpTy>(
stateOp->getLoc(), stateOp->getResultTypes(), callee,
materializedOperands);

// Create the write ops that write the result of the transfer function to the
// allocated state storage.
for (auto [alloc, result] :
llvm::zip(allocatedStates, newStateOp.getResults()))
nonResetBuilder.create<StateWriteOp>(stateOp.getLoc(), alloc, result,
nonResetBuilder.create<StateWriteOp>(stateOp->getLoc(), alloc, result,
info.enable);

// Replace all uses of the arc with reads from the allocated state.
for (auto [alloc, result] : llvm::zip(allocatedStates, stateOp.getResults()))
for (auto [alloc, result] : llvm::zip(allocatedStates, stateOp->getResults()))
replaceValueWithStateRead(result, alloc);
stateOp.erase();
stateOp->erase();
return success();
}

LogicalResult ModuleLowering::lowerState(StateOp stateOp) {
// We don't support arcs beyond latency 1 yet. These should be easy to add in
// the future though.
if (stateOp.getLatency() > 1)
return stateOp.emitError("state with latency > 1 not supported");

auto stateInputs = SmallVector<Value>(stateOp.getInputs());

return lowerStateLike<arc::CallOp>(stateOp, stateOp.getClock(),
stateOp.getEnable(), stateOp.getReset(),
stateInputs, stateOp.getArcAttr());
}

LogicalResult ModuleLowering::lowerState(sim::DPICallOp callOp) {
// Clocked call op can be considered as arc state with single latency.
auto stateClock = callOp.getClock();
if (!stateClock)
return callOp.emitError("unclocked DPI call not implemented yet");

auto stateInputs = SmallVector<Value>(callOp.getInputs());

return lowerStateLike<func::CallOp>(callOp, stateClock, callOp.getEnable(),
Value(), stateInputs,
callOp.getCalleeAttr());
}

LogicalResult ModuleLowering::lowerState(MemoryOp memOp) {
auto allocMemOp = stateBuilder.create<AllocMemoryOp>(
memOp.getLoc(), memOp.getType(), storageArg, memOp->getAttrs());
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Sim/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ add_circt_dialect_library(CIRCTSim
CIRCTHW
CIRCTSeq
CIRCTSV
MLIRFuncDialect
MLIRIR
MLIRPass
MLIRTransforms
Expand Down
10 changes: 7 additions & 3 deletions lib/Dialect/Sim/SimOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "circt/Dialect/Sim/SimOps.h"
#include "circt/Dialect/HW/ModuleImplementation.h"
#include "circt/Dialect/SV/SVOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/FunctionImplementation.h"

Expand Down Expand Up @@ -69,12 +70,15 @@ ParseResult DPIFuncOp::parse(OpAsmParser &parser, OperationState &result) {

LogicalResult
sim::DPICallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto referencedOp = dyn_cast_or_null<sim::DPIFuncOp>(
symbolTable.lookupNearestSymbolFrom(*this, getCalleeAttr()));
auto referencedOp =
symbolTable.lookupNearestSymbolFrom(*this, getCalleeAttr());
if (!referencedOp)
return emitError("cannot find function declaration '")
<< getCallee() << "'";
return success();
if (isa<func::FuncOp, sim::DPIFuncOp>(referencedOp))
return success();
return emitError("callee must be 'sim.dpi.func' or 'func.func' but got '")
<< referencedOp->getName() << "'";
}

void DPIFuncOp::print(OpAsmPrinter &p) {
Expand Down
3 changes: 3 additions & 0 deletions lib/Dialect/Sim/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
add_circt_dialect_library(CIRCTSimTransforms
LowerDPIFunc.cpp
ProceduralizeSim.cpp


Expand All @@ -12,8 +13,10 @@ add_circt_dialect_library(CIRCTSimTransforms
CIRCTSV
CIRCTComb
CIRCTSupport
MLIRFuncDialect
MLIRIR
MLIRPass
MLIRLLVMDialect
MLIRSCFDialect
MLIRTransformUtils
)
Loading
Loading