Skip to content

Commit

Permalink
Revert "[mlir] share argument attributes interface between calls and …
Browse files Browse the repository at this point in the history
…callables (llvm#123176)"

This reverts commit 327d627.
  • Loading branch information
hanhanW committed Feb 4, 2025
1 parent c06d0ff commit 1a708bd
Show file tree
Hide file tree
Showing 32 changed files with 256 additions and 452 deletions.
4 changes: 1 addition & 3 deletions flang/include/flang/Optimizer/Dialect/CUF/CUFOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,7 @@ def cuf_KernelLaunchOp : cuf_Op<"kernel_launch", [CallOpInterface,
I32:$block_z,
Optional<I32>:$bytes,
Optional<I32>:$stream,
Variadic<AnyType>:$args,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs
Variadic<AnyType>:$args
);

let assemblyFormat = [{
Expand Down
4 changes: 0 additions & 4 deletions flang/include/flang/Optimizer/Dialect/FIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2432,8 +2432,6 @@ def fir_CallOp : fir_Op<"call",
let arguments = (ins
OptionalAttr<SymbolRefAttr>:$callee,
Variadic<AnyType>:$args,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs,
OptionalAttr<fir_FortranProcedureFlagsAttr>:$procedure_attrs,
DefaultValuedAttr<Arith_FastMathAttr,
"::mlir::arith::FastMathFlags::none">:$fastmath
Expand Down Expand Up @@ -2520,8 +2518,6 @@ def fir_DispatchOp : fir_Op<"dispatch", []> {
fir_ClassType:$object,
Variadic<AnyType>:$args,
OptionalAttr<I32Attr>:$pass_arg_pos,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs,
OptionalAttr<fir_FortranProcedureFlagsAttr>:$procedure_attrs
);

Expand Down
12 changes: 4 additions & 8 deletions flang/lib/Lower/ConvertCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -594,8 +594,7 @@ Fortran::lower::genCallOpAndResult(

builder.create<cuf::KernelLaunchOp>(
loc, funcType.getResults(), funcSymbolAttr, grid_x, grid_y, grid_z,
block_x, block_y, block_z, bytes, stream, operands,
/*arg_attrs=*/nullptr, /*res_attrs=*/nullptr);
block_x, block_y, block_z, bytes, stream, operands);
callNumResults = 0;
} else if (caller.requireDispatchCall()) {
// Procedure call requiring a dynamic dispatch. Call is created with
Expand All @@ -622,8 +621,7 @@ Fortran::lower::genCallOpAndResult(
dispatch = builder.create<fir::DispatchOp>(
loc, funcType.getResults(), builder.getStringAttr(procName),
caller.getInputs()[*passArg], operands,
builder.getI32IntegerAttr(*passArg), /*arg_attrs=*/nullptr,
/*res_attrs=*/nullptr, procAttrs);
builder.getI32IntegerAttr(*passArg), procAttrs);
} else {
// NOPASS
const Fortran::evaluate::Component *component =
Expand All @@ -638,17 +636,15 @@ Fortran::lower::genCallOpAndResult(
passObject = builder.create<fir::LoadOp>(loc, passObject);
dispatch = builder.create<fir::DispatchOp>(
loc, funcType.getResults(), builder.getStringAttr(procName),
passObject, operands, nullptr, /*arg_attrs=*/nullptr,
/*res_attrs=*/nullptr, procAttrs);
passObject, operands, nullptr, procAttrs);
}
callNumResults = dispatch.getNumResults();
if (callNumResults != 0)
callResult = dispatch.getResult(0);
} else {
// Standard procedure call with fir.call.
auto call = builder.create<fir::CallOp>(
loc, funcType.getResults(), funcSymbolAttr, operands,
/*arg_attrs=*/nullptr, /*res_attrs=*/nullptr, procAttrs);
loc, funcType.getResults(), funcSymbolAttr, operands, procAttrs);

callNumResults = call.getNumResults();
if (callNumResults != 0)
Expand Down
2 changes: 0 additions & 2 deletions flang/lib/Optimizer/CodeGen/TargetRewrite.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,6 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
newOpers.insert(newOpers.end(), trailingOpers.begin(), trailingOpers.end());

llvm::SmallVector<mlir::Value, 1> newCallResults;
// TODO propagate/update call argument and result attributes.
if constexpr (std::is_same_v<std::decay_t<A>, mlir::gpu::LaunchFuncOp>) {
auto newCall = rewriter->create<A>(
loc, callOp.getKernel(), callOp.getGridSizeOperandValues(),
Expand Down Expand Up @@ -558,7 +557,6 @@ class TargetRewrite : public fir::impl::TargetRewritePassBase<TargetRewrite> {
loc, newResTys, rewriter->getStringAttr(callOp.getMethod()),
callOp.getOperands()[0], newOpers,
rewriter->getI32IntegerAttr(*callOp.getPassArgPos() + passArgShift),
/*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
callOp.getProcedureAttrsAttr());
if (wrap)
newCallResults.push_back((*wrap)(dispatchOp.getOperation()));
Expand Down
3 changes: 0 additions & 3 deletions flang/lib/Optimizer/Transforms/AbstractResult.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,6 @@ class CallConversion : public mlir::OpRewritePattern<Op> {
newResultTypes.emplace_back(getVoidPtrType(result.getContext()));

Op newOp;
// TODO: propagate argument and result attributes (need to be shifted).
// fir::CallOp specific handling.
if constexpr (std::is_same_v<Op, fir::CallOp>) {
if (op.getCallee()) {
Expand Down Expand Up @@ -190,11 +189,9 @@ class CallConversion : public mlir::OpRewritePattern<Op> {
if (op.getPassArgPos())
passArgPos =
rewriter.getI32IntegerAttr(*op.getPassArgPos() + passArgShift);
// TODO: propagate argument and result attributes (need to be shifted).
newOp = rewriter.create<fir::DispatchOp>(
loc, newResultTypes, rewriter.getStringAttr(op.getMethod()),
op.getOperands()[0], newOperands, passArgPos,
/*arg_attrs=*/nullptr, /*res_attrs=*/nullptr,
op.getProcedureAttrsAttr());
}

Expand Down
5 changes: 2 additions & 3 deletions flang/lib/Optimizer/Transforms/PolymorphicOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,8 @@ struct DispatchOpConv : public OpConversionPattern<fir::DispatchOp> {
// Make the call.
llvm::SmallVector<mlir::Value> args{funcPtr};
args.append(dispatch.getArgs().begin(), dispatch.getArgs().end());
rewriter.replaceOpWithNewOp<fir::CallOp>(
dispatch, resTypes, nullptr, args, dispatch.getArgAttrsAttr(),
dispatch.getResAttrsAttr(), dispatch.getProcedureAttrsAttr());
rewriter.replaceOpWithNewOp<fir::CallOp>(dispatch, resTypes, nullptr, args,
dispatch.getProcedureAttrsAttr());
return mlir::success();
}

Expand Down
7 changes: 1 addition & 6 deletions mlir/docs/Interfaces.md
Original file line number Diff line number Diff line change
Expand Up @@ -753,15 +753,10 @@ interface section goes as follows:
- (`C++ class` -- `ODS class`(if applicable))

##### CallInterfaces

* `CallOpInterface` - Used to represent operations like 'call'
- `CallInterfaceCallable getCallableForCallee()`
- `void setCalleeFromCallable(CallInterfaceCallable)`
- `ArrayAttr getArgAttrsAttr()`
- `ArrayAttr getResAttrsAttr()`
- `void setArgAttrsAttr(ArrayAttr)`
- `void setResAttrsAttr(ArrayAttr)`
- `Attribute removeArgAttrsAttr()`
- `Attribute removeResAttrsAttr()`
* `CallableOpInterface` - Used to represent the target callee of call.
- `Region * getCallableRegion()`
- `ArrayRef<Type> getArgumentTypes()`
Expand Down
7 changes: 1 addition & 6 deletions mlir/examples/toy/Ch4/include/toy/Ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -215,12 +215,7 @@ def GenericCallOp : Toy_Op<"generic_call",

// The generic call operation takes a symbol reference attribute as the
// callee, and inputs for the call.
let arguments = (ins
FlatSymbolRefAttr:$callee,
Variadic<F64Tensor>:$inputs,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs
);
let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);

// The generic call operation returns a single value of TensorType.
let results = (outs F64Tensor);
Expand Down
7 changes: 1 addition & 6 deletions mlir/examples/toy/Ch5/include/toy/Ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -214,12 +214,7 @@ def GenericCallOp : Toy_Op<"generic_call",

// The generic call operation takes a symbol reference attribute as the
// callee, and inputs for the call.
let arguments = (ins
FlatSymbolRefAttr:$callee,
Variadic<F64Tensor>:$inputs,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs
);
let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);

// The generic call operation returns a single value of TensorType.
let results = (outs F64Tensor);
Expand Down
7 changes: 1 addition & 6 deletions mlir/examples/toy/Ch6/include/toy/Ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -214,12 +214,7 @@ def GenericCallOp : Toy_Op<"generic_call",

// The generic call operation takes a symbol reference attribute as the
// callee, and inputs for the call.
let arguments = (ins
FlatSymbolRefAttr:$callee,
Variadic<F64Tensor>:$inputs,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs
);
let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<F64Tensor>:$inputs);

// The generic call operation returns a single value of TensorType.
let results = (outs F64Tensor);
Expand Down
10 changes: 2 additions & 8 deletions mlir/examples/toy/Ch7/include/toy/Ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -237,12 +237,7 @@ def GenericCallOp : Toy_Op<"generic_call",

// The generic call operation takes a symbol reference attribute as the
// callee, and inputs for the call.
let arguments = (ins
FlatSymbolRefAttr:$callee,
Variadic<Toy_Type>:$inputs,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs
);
let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<Toy_Type>:$inputs);

// The generic call operation returns a single value of TensorType or
// StructType.
Expand All @@ -255,8 +250,7 @@ def GenericCallOp : Toy_Op<"generic_call",

// Add custom build methods for the generic call operation.
let builders = [
OpBuilder<(ins "Type":$result_type, "StringRef":$callee,
"ArrayRef<Value>":$arguments)>
OpBuilder<(ins "StringRef":$callee, "ArrayRef<Value>":$arguments)>
];
}

Expand Down
6 changes: 3 additions & 3 deletions mlir/examples/toy/Ch7/mlir/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,9 +350,9 @@ void FuncOp::print(mlir::OpAsmPrinter &p) {
//===----------------------------------------------------------------------===//

void GenericCallOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
mlir::Type resultType, StringRef callee,
ArrayRef<mlir::Value> arguments) {
state.addTypes(resultType);
StringRef callee, ArrayRef<mlir::Value> arguments) {
// Generic call always returns an unranked Tensor initially.
state.addTypes(UnrankedTensorType::get(builder.getF64Type()));
state.addOperands(arguments);
state.addAttribute("callee",
mlir::SymbolRefAttr::get(builder.getContext(), callee));
Expand Down
3 changes: 2 additions & 1 deletion mlir/examples/toy/Ch7/mlir/MLIRGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,8 @@ class MLIRGenImpl {
}
mlir::toy::FuncOp calledFunc = calledFuncIt->second;
return builder.create<GenericCallOp>(
location, calledFunc.getFunctionType().getResult(0), callee, operands);
location, calledFunc.getFunctionType().getResult(0),
mlir::SymbolRefAttr::get(builder.getContext(), callee), operands);
}

/// Emit a print expression. It emits specific operations for two builtins:
Expand Down
8 changes: 1 addition & 7 deletions mlir/include/mlir/Dialect/Async/IR/AsyncOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -208,13 +208,7 @@ def Async_CallOp : Async_Op<"call",
```
}];

let arguments = (ins
FlatSymbolRefAttr:$callee,
Variadic<AnyType>:$operands,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs
);

let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<AnyType>:$operands);
let results = (outs Variadic<Async_AnyValueOrTokenType>);

let builders = [
Expand Down
8 changes: 1 addition & 7 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -551,13 +551,7 @@ def EmitC_CallOp : EmitC_Op<"call",
%2 = emitc.call @my_add(%0, %1) : (f32, f32) -> f32
```
}];
let arguments = (ins
FlatSymbolRefAttr:$callee,
Variadic<EmitCType>:$operands,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs
);

let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<EmitCType>:$operands);
let results = (outs Variadic<EmitCType>);

let builders = [
Expand Down
31 changes: 4 additions & 27 deletions mlir/include/mlir/Dialect/Func/IR/FuncOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,8 @@ def CallOp : Func_Op<"call",
```
}];

let arguments = (ins
FlatSymbolRefAttr:$callee,
Variadic<AnyType>:$operands,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs,
UnitAttr:$no_inline
);

let arguments = (ins FlatSymbolRefAttr:$callee, Variadic<AnyType>:$operands,
UnitAttr:$no_inline);
let results = (outs Variadic<AnyType>);

let builders = [
Expand All @@ -79,18 +73,6 @@ def CallOp : Func_Op<"call",
CArg<"ValueRange", "{}">:$operands), [{
build($_builder, $_state, StringAttr::get($_builder.getContext(), callee),
results, operands);
}]>,
OpBuilder<(ins "TypeRange":$results, "FlatSymbolRefAttr":$callee,
CArg<"ValueRange", "{}">:$operands), [{
build($_builder, $_state, callee, results, operands);
}]>,
OpBuilder<(ins "TypeRange":$results, "StringAttr":$callee,
CArg<"ValueRange", "{}">:$operands), [{
build($_builder, $_state, callee, results, operands);
}]>,
OpBuilder<(ins "TypeRange":$results, "StringRef":$callee,
CArg<"ValueRange", "{}">:$operands), [{
build($_builder, $_state, callee, results, operands);
}]>];

let extraClassDeclaration = [{
Expand Down Expand Up @@ -154,13 +136,8 @@ def CallIndirectOp : Func_Op<"call_indirect", [
```
}];

let arguments = (ins
FunctionType:$callee,
Variadic<AnyType>:$callee_operands,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs
);

let arguments = (ins FunctionType:$callee,
Variadic<AnyType>:$callee_operands);
let results = (outs Variadic<AnyType>:$results);

let builders = [
Expand Down
6 changes: 1 addition & 5 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -633,8 +633,6 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$var_callee_type,
OptionalAttr<FlatSymbolRefAttr>:$callee,
Variadic<LLVM_Type>:$callee_operands,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs,
Variadic<LLVM_Type>:$normalDestOperands,
Variadic<LLVM_Type>:$unwindDestOperands,
OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
Expand Down Expand Up @@ -757,9 +755,7 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
VariadicOfVariadic<LLVM_Type,
"op_bundle_sizes">:$op_bundle_operands,
DenseI32ArrayAttr:$op_bundle_sizes,
OptionalAttr<ArrayAttr>:$op_bundle_tags,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs);
OptionalAttr<ArrayAttr>:$op_bundle_tags);
// Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
let arguments = !con(args, aliasAttrs);
let results = (outs Optional<LLVM_Type>:$result);
Expand Down
13 changes: 1 addition & 12 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -214,24 +214,13 @@ def SPIRV_FunctionCallOp : SPIRV_Op<"FunctionCall", [

let arguments = (ins
FlatSymbolRefAttr:$callee,
Variadic<SPIRV_Type>:$arguments,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs
Variadic<SPIRV_Type>:$arguments
);

let results = (outs
Optional<SPIRV_Type>:$return_value
);

let builders = [
OpBuilder<(ins "Type":$returnType, "FlatSymbolRefAttr":$callee,
"ValueRange":$arguments),
[{
build($_builder, $_state, returnType, callee, arguments,
/*arg_attrs=*/nullptr, /*res_attrs=*/nullptr);
}]>
];

let autogenSerialization = 0;

let assemblyFormat = [{
Expand Down
4 changes: 1 addition & 3 deletions mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -886,9 +886,7 @@ def IncludeOp : TransformDialectOp<"include",

let arguments = (ins SymbolRefAttr:$target,
FailurePropagationMode:$failure_propagation_mode,
Variadic<Transform_AnyHandleOrParamType>:$operands,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs);
Variadic<Transform_AnyHandleOrParamType>:$operands);
let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);

let assemblyFormat =
Expand Down
Loading

0 comments on commit 1a708bd

Please sign in to comment.