Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][OpenMP] Simplify definition of the BlockArgOpenMPOpInterface, NFC #128198

Merged
merged 2 commits into from
Feb 24, 2025

Conversation

skatrak
Copy link
Member

@skatrak skatrak commented Feb 21, 2025

This patch removes code duplication from the definition of methods of the BlockArgOpenMPOpInterface and makes the order relationship between entry block argument generating clauses explicit.

The goal of this change is to make the addition of clauses and methods to the interface less error-prone.

This patch removes code duplication from the definition of methods of the
`BlockArgOpenMPOpInterface` and makes the order relationship between entry
block argument generating clauses explicit.

The goal of this change is to make the addition of clauses and methods to the
interface less error-prone.
@llvmbot
Copy link
Member

llvmbot commented Feb 21, 2025

@llvm/pr-subscribers-mlir-openmp

@llvm/pr-subscribers-mlir

Author: Sergio Afonso (skatrak)

Changes

This patch removes code duplication from the definition of methods of the BlockArgOpenMPOpInterface and makes the order relationship between entry block argument generating clauses explicit.

The goal of this change is to make the addition of clauses and methods to the interface less error-prone.


Full diff: https://github.com/llvm/llvm-project/pull/128198.diff

1 Files Affected:

  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td (+70-143)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
index c863e5772032c..3d838901a85f3 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
@@ -15,6 +15,62 @@
 
 include "mlir/IR/OpBase.td"
 
+
+// Internal class to hold definitions of BlockArgOpenMPOpInterface methods,
+// based on the name of the clause and what clause comes earlier in the list.
+//
+// The clause order will define the expected relative order between block
+// arguments corresponding to each of these clauses.
+class BlockArgOpenMPClause<string clauseNameSnake, string clauseNameCamel,
+    BlockArgOpenMPClause previousClause> {
+  // Default-implemented method to be overriden by the corresponding clause.
+  InterfaceMethod numArgsMethod = InterfaceMethod<
+    "Get number of block arguments defined by `" # clauseNameSnake # "`.",
+    "unsigned", "num" # clauseNameCamel # "BlockArgs", (ins), [{}], [{
+      return 0;
+    }]
+  >;
+
+  // Unified access method for the start index of clause-associated entry block
+  // arguments.
+  InterfaceMethod startMethod = InterfaceMethod<
+    "Get start index of block arguments defined by `" # clauseNameSnake # "`.",
+    "unsigned", "get" # clauseNameCamel # "BlockArgsStart", (ins),
+    !if(!initialized(previousClause), [{
+        auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
+      }] # "return iface." # previousClause.startMethod.name # "() + $_op."
+        # previousClause.numArgsMethod.name # "();",
+        "return 0;"
+    )
+  >;
+
+  // Unified access method for clause-associated entry block arguments.
+  InterfaceMethod blockArgsMethod = InterfaceMethod<
+    "Get block arguments defined by `" # clauseNameSnake # "`.",
+    "::llvm::MutableArrayRef<::mlir::BlockArgument>",
+    "get" # clauseNameCamel # "BlockArgs", (ins), [{
+      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
+      return $_op->getRegion(0).getArguments().slice(
+    }] # "iface." # startMethod.name # "(), $_op." # numArgsMethod.name # "());"
+  >;
+}
+
+def BlockArgHostEvalClause : BlockArgOpenMPClause<"host_eval", "HostEval", ?>;
+def BlockArgInReductionClause : BlockArgOpenMPClause<
+    "in_reduction", "InReduction", BlockArgHostEvalClause>;
+def BlockArgMapClause : BlockArgOpenMPClause<
+    "map", "Map", BlockArgInReductionClause>;
+def BlockArgPrivateClause : BlockArgOpenMPClause<
+    "private", "Private", BlockArgMapClause>;
+def BlockArgReductionClause : BlockArgOpenMPClause<
+    "reduction", "Reduction", BlockArgPrivateClause>;
+def BlockArgTaskReductionClause : BlockArgOpenMPClause<
+    "task_reduction", "TaskReduction", BlockArgReductionClause>;
+def BlockArgUseDeviceAddrClause : BlockArgOpenMPClause<
+    "use_device_addr", "UseDeviceAddr", BlockArgTaskReductionClause>;
+def BlockArgUseDevicePtrClause : BlockArgOpenMPClause<
+    "use_device_ptr", "UseDevicePtr", BlockArgUseDeviceAddrClause>;
+
 def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
   let description = [{
     OpenMP operations that define entry block arguments as part of the
@@ -23,153 +79,24 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
 
   let cppNamespace = "::mlir::omp";
 
-  let methods = [
-    // Default-implemented methods to be overriden by the corresponding clauses.
-    InterfaceMethod<"Get number of block arguments defined by `host_eval`.",
-                    "unsigned", "numHostEvalBlockArgs", (ins), [{}], [{
-      return 0;
-    }]>,
-    InterfaceMethod<"Get number of block arguments defined by `in_reduction`.",
-                    "unsigned", "numInReductionBlockArgs", (ins), [{}], [{
-      return 0;
-    }]>,
-    InterfaceMethod<"Get number of block arguments defined by `map`.",
-                    "unsigned", "numMapBlockArgs", (ins), [{}], [{
-      return 0;
-    }]>,
-    InterfaceMethod<"Get number of block arguments defined by `private`.",
-                    "unsigned", "numPrivateBlockArgs", (ins), [{}], [{
-      return 0;
-    }]>,
-    InterfaceMethod<"Get number of block arguments defined by `reduction`.",
-                    "unsigned", "numReductionBlockArgs", (ins), [{}], [{
-      return 0;
-    }]>,
-    InterfaceMethod<"Get number of block arguments defined by `task_reduction`.",
-                    "unsigned", "numTaskReductionBlockArgs", (ins), [{}], [{
-      return 0;
-    }]>,
-    InterfaceMethod<"Get number of block arguments defined by `use_device_addr`.",
-                    "unsigned", "numUseDeviceAddrBlockArgs", (ins), [{}], [{
-      return 0;
-    }]>,
-    InterfaceMethod<"Get number of block arguments defined by `use_device_ptr`.",
-                    "unsigned", "numUseDevicePtrBlockArgs", (ins), [{}], [{
-      return 0;
-    }]>,
+  defvar clauses = [ BlockArgHostEvalClause, BlockArgInReductionClause,
+    BlockArgMapClause, BlockArgPrivateClause, BlockArgReductionClause,
+    BlockArgTaskReductionClause, BlockArgUseDeviceAddrClause,
+    BlockArgUseDevicePtrClause ];
 
-    // Unified access methods for start indices of clause-associated entry block
-    // arguments.
-    InterfaceMethod<"Get start index of block arguments defined by `host_eval`.",
-                    "unsigned", "getHostEvalBlockArgsStart", (ins), [{
-      return 0;
-    }]>,
-    InterfaceMethod<"Get start index of block arguments defined by `in_reduction`.",
-                    "unsigned", "getInReductionBlockArgsStart", (ins), [{
-      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
-      return iface.getHostEvalBlockArgsStart() + $_op.numHostEvalBlockArgs();
-    }]>,
-    InterfaceMethod<"Get start index of block arguments defined by `map`.",
-                    "unsigned", "getMapBlockArgsStart", (ins), [{
-      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
-      return iface.getInReductionBlockArgsStart() +
-             $_op.numInReductionBlockArgs();
-    }]>,
-    InterfaceMethod<"Get start index of block arguments defined by `private`.",
-                    "unsigned", "getPrivateBlockArgsStart", (ins), [{
-      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
-      return iface.getMapBlockArgsStart() + $_op.numMapBlockArgs();
-    }]>,
-    InterfaceMethod<"Get start index of block arguments defined by `reduction`.",
-                    "unsigned", "getReductionBlockArgsStart", (ins), [{
-      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
-      return iface.getPrivateBlockArgsStart() + $_op.numPrivateBlockArgs();
-    }]>,
-    InterfaceMethod<"Get start index of block arguments defined by `task_reduction`.",
-                    "unsigned", "getTaskReductionBlockArgsStart", (ins), [{
-      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
-      return iface.getReductionBlockArgsStart() + $_op.numReductionBlockArgs();
-    }]>,
-    InterfaceMethod<"Get start index of block arguments defined by `use_device_addr`.",
-                    "unsigned", "getUseDeviceAddrBlockArgsStart", (ins), [{
-      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
-      return iface.getTaskReductionBlockArgsStart() + $_op.numTaskReductionBlockArgs();
-    }]>,
-    InterfaceMethod<"Get start index of block arguments defined by `use_device_ptr`.",
-                    "unsigned", "getUseDevicePtrBlockArgsStart", (ins), [{
-      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
-      return iface.getUseDeviceAddrBlockArgsStart() + $_op.numUseDeviceAddrBlockArgs();
-    }]>,
-
-    // Unified access methods for clause-associated entry block arguments.
-    InterfaceMethod<"Get block arguments defined by `host_eval`.",
-                    "::llvm::MutableArrayRef<::mlir::BlockArgument>",
-                    "getHostEvalBlockArgs", (ins), [{
-      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
-      return $_op->getRegion(0).getArguments().slice(
-          iface.getHostEvalBlockArgsStart(), $_op.numHostEvalBlockArgs());
-    }]>,
-    InterfaceMethod<"Get block arguments defined by `in_reduction`.",
-                    "::llvm::MutableArrayRef<::mlir::BlockArgument>",
-                    "getInReductionBlockArgs", (ins), [{
-      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
-      return $_op->getRegion(0).getArguments().slice(
-          iface.getInReductionBlockArgsStart(), $_op.numInReductionBlockArgs());
-    }]>,
-    InterfaceMethod<"Get block arguments defined by `map`.",
-                    "::llvm::MutableArrayRef<::mlir::BlockArgument>",
-                    "getMapBlockArgs", (ins), [{
-      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
-      return $_op->getRegion(0).getArguments().slice(
-          iface.getMapBlockArgsStart(), $_op.numMapBlockArgs());
-    }]>,
-    InterfaceMethod<"Get block arguments defined by `private`.",
-                    "::llvm::MutableArrayRef<::mlir::BlockArgument>",
-                    "getPrivateBlockArgs", (ins), [{
-      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
-      return $_op->getRegion(0).getArguments().slice(
-          iface.getPrivateBlockArgsStart(), $_op.numPrivateBlockArgs());
-    }]>,
-    InterfaceMethod<"Get block arguments defined by `reduction`.",
-                    "::llvm::MutableArrayRef<::mlir::BlockArgument>",
-                    "getReductionBlockArgs", (ins), [{
-      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
-      return $_op->getRegion(0).getArguments().slice(
-          iface.getReductionBlockArgsStart(), $_op.numReductionBlockArgs());
-    }]>,
-    InterfaceMethod<"Get block arguments defined by `task_reduction`.",
-                    "::llvm::MutableArrayRef<::mlir::BlockArgument>",
-                    "getTaskReductionBlockArgs", (ins), [{
-      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
-      return $_op->getRegion(0).getArguments().slice(
-          iface.getTaskReductionBlockArgsStart(),
-          $_op.numTaskReductionBlockArgs());
-    }]>,
-    InterfaceMethod<"Get block arguments defined by `use_device_addr`.",
-                    "::llvm::MutableArrayRef<::mlir::BlockArgument>",
-                    "getUseDeviceAddrBlockArgs", (ins), [{
-      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
-      return $_op->getRegion(0).getArguments().slice(
-          iface.getUseDeviceAddrBlockArgsStart(),
-          $_op.numUseDeviceAddrBlockArgs());
-    }]>,
-    InterfaceMethod<"Get block arguments defined by `use_device_ptr`.",
-                    "::llvm::MutableArrayRef<::mlir::BlockArgument>",
-                    "getUseDevicePtrBlockArgs", (ins), [{
-      auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
-      return $_op->getRegion(0).getArguments().slice(
-          iface.getUseDevicePtrBlockArgsStart(),
-          $_op.numUseDevicePtrBlockArgs());
-    }]>,
-  ];
+  let methods = !listconcat(
+    !foreach(clause, clauses, clause.numArgsMethod),
+    !foreach(clause, clauses, clause.startMethod),
+    !foreach(clause, clauses, clause.blockArgsMethod)
+  );
 
   let verify = [{
     auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>($_op);
-    unsigned expectedArgs = iface.numHostEvalBlockArgs() +
-        iface.numInReductionBlockArgs() + iface.numMapBlockArgs() +
-        iface.numPrivateBlockArgs() + iface.numReductionBlockArgs() +
-        iface.numTaskReductionBlockArgs() + iface.numUseDeviceAddrBlockArgs() +
-        iface.numUseDevicePtrBlockArgs();
+  }] # "unsigned expectedArgs = "
+     # !interleave(
+         !foreach(clause, clauses, "iface." # clause.numArgsMethod.name # "()"),
+         " + "
+       ) # ";" # [{
     if ($_op->getRegion(0).getNumArguments() < expectedArgs)
       return $_op->emitOpError() << "expected at least " << expectedArgs
                                  << " entry block argument(s)";

// Default-implemented method to be overriden by the corresponding clause.
InterfaceMethod numArgsMethod = InterfaceMethod<
"Get number of block arguments defined by `" # clauseNameSnake # "`.",
"unsigned", "num" # clauseNameCamel # "BlockArgs", (ins), [{}], [{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a comment with an example of use? It's pretty clear in the code, but you need to know where to look...

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea. I just added some usage examples, let me know if that's what you were looking for.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Thanks!

Copy link
Contributor

@kparzysz kparzysz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

// Default-implemented method to be overriden by the corresponding clause.
InterfaceMethod numArgsMethod = InterfaceMethod<
"Get number of block arguments defined by `" # clauseNameSnake # "`.",
"unsigned", "num" # clauseNameCamel # "BlockArgs", (ins), [{}], [{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Thanks!

Copy link
Contributor

@tblah tblah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

I am worried that this is getting gradually harder and harder for anyone new to read the code and understand what is going on (for example somebody who only knows the basics of tablegen would have to think quite hard to know even what methods are defined), but considered on its own this patch is a clear maintainability improvement and the documentation @kparzysz suggested helps a lot.

@kparzysz kparzysz merged commit ff7790e into llvm:main Feb 24, 2025
11 checks passed
@skatrak skatrak deleted the entry-block-arg-iface-refactor branch February 24, 2025 16:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants