diff --git a/include/xten/Dialect/XTenNN/IR/CMakeLists.txt b/include/xten/Dialect/XTenNN/IR/CMakeLists.txt index c9b122cd..43f42ede 100644 --- a/include/xten/Dialect/XTenNN/IR/CMakeLists.txt +++ b/include/xten/Dialect/XTenNN/IR/CMakeLists.txt @@ -12,6 +12,8 @@ mlir_tablegen(XTenNNBase.h.inc -gen-dialect-decls) mlir_tablegen(XTenNNBase.cpp.inc -gen-dialect-defs) mlir_tablegen(XTenNNOps.h.inc -gen-op-decls) mlir_tablegen(XTenNNOps.cpp.inc -gen-op-defs) +mlir_tablegen(XTenNNEnums.h.inc -gen-enum-decls) +mlir_tablegen(XTenNNEnums.cpp.inc -gen-enum-defs) add_public_tablegen_target(XTenNNDialectIncGen) add_dependencies(XTenNNIncGen XTenNNDialectIncGen) diff --git a/include/xten/Dialect/XTenNN/IR/XTenNNOps.h b/include/xten/Dialect/XTenNN/IR/XTenNNOps.h index ad8e4ecb..8d44cebc 100644 --- a/include/xten/Dialect/XTenNN/IR/XTenNNOps.h +++ b/include/xten/Dialect/XTenNN/IR/XTenNNOps.h @@ -22,6 +22,9 @@ #include "mlir/Interfaces/SideEffectInterfaces.h" #include "xten/Dialect/XTenNN/Interfaces/EnclaveOpInterfaces.h" +// Include enums +#include "xten/Dialect/XTenNN/IR/XTenNNEnums.h.inc" + namespace mlir::OpTrait { template class TosaExtension : public TraitBase {}; diff --git a/include/xten/Dialect/XTenNN/IR/XTenNNOps.td b/include/xten/Dialect/XTenNN/IR/XTenNNOps.td index 00106456..93d01b5c 100644 --- a/include/xten/Dialect/XTenNN/IR/XTenNNOps.td +++ b/include/xten/Dialect/XTenNN/IR/XTenNNOps.td @@ -16,6 +16,7 @@ include "xten/Dialect/XTenNN/IR/XTenNNBase.td" include "xten/Dialect/XTenNN/IR/XTenNNTypes.td" +include "mlir/IR/EnumAttr.td" include "mlir/Interfaces/FunctionInterfaces.td" include "mlir/IR/OpAsmInterface.td" include "mlir/IR/RegionKindInterface.td" @@ -35,6 +36,29 @@ class XTenNN_Op traits = []> // SubgraphOp //===----------------------------------------------------------------------===// +def MicroKernel: I32EnumAttrCase<"MicroKernel", 0>; +def TemplatedGraph: I32EnumAttrCase<"TemplatedGraph", 1>; +def InCoreChain: I32EnumAttrCase<"InCoreChain", 2>; +def MllibKernel: I32EnumAttrCase<"MllibKernel", 3>; +def PseudoOp: I32EnumAttrCase<"PseudoOp", 4>; +def MemTileChain: I32EnumAttrCase<"MemTileChain", 5>; +def Interface: I32EnumAttrCase<"Interface", 6>; +def CpuBecause: I32EnumAttrCase<"CpuBecause", 7>; + +def ReasonEnum : I32EnumAttr<"ReasonEnum", "List of possible reasons for a subgraph to exist", + [MicroKernel, + TemplatedGraph, + InCoreChain, + MllibKernel, + PseudoOp, + MemTileChain, + Interface, + CpuBecause]> { + let cppNamespace = "amd::xten_nn"; + let stringToSymbolFnName = "reasonStrToEnum"; + let symbolToStringFnName = "toString"; +} + def XTenNN_SubgraphOp : XTenNN_Op<"subgraph", [ DeclareOpInterfaceMethods, @@ -71,7 +95,7 @@ def XTenNN_SubgraphOp : XTenNN_Op<"subgraph", [ ``` }]; - let arguments = (ins Variadic:$captures); + let arguments = (ins Variadic:$captures, StrAttr:$Reason); let results = (outs Variadic:$results); let regions = (region MaxSizedRegion<1>:$content); diff --git a/lib/Dialect/XTenNN/IR/CMakeLists.txt b/lib/Dialect/XTenNN/IR/CMakeLists.txt index 5b468f08..258647cb 100644 --- a/lib/Dialect/XTenNN/IR/CMakeLists.txt +++ b/lib/Dialect/XTenNN/IR/CMakeLists.txt @@ -8,6 +8,7 @@ add_mlir_dialect_library(XTenNNDialect XTenNNBase.cpp + XTenNNEnums.cpp XTenNNOps.cpp DEPENDS diff --git a/lib/Dialect/XTenNN/IR/XTenNNEnums.cpp b/lib/Dialect/XTenNN/IR/XTenNNEnums.cpp new file mode 100644 index 00000000..d9067b82 --- /dev/null +++ b/lib/Dialect/XTenNN/IR/XTenNNEnums.cpp @@ -0,0 +1,16 @@ +//===- XTenNNEnums.cpp ------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +// (c) Copyright 2024 Advanced Micro Devices, Inc. All Rights reserved. + +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/BuiltinOps.h" + +#include "xten/Dialect/XTenNN/IR/XTenNNEnums.h.inc" +#include "xten/Dialect/XTenNN/IR/XTenNNEnums.cpp.inc" diff --git a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp index 83afc80a..1384bf69 100644 --- a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp +++ b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp @@ -212,6 +212,27 @@ LogicalResult SubgraphOp::verify() { << ") does not match argument type (" << argType << ")"; } } + + std::optional reason = reasonStrToEnum(getReason()); + if (!reason) { + SmallVector validReasons; + validReasons.reserve(getMaxEnumValForReasonEnum()); + + for (unsigned i = 0; i <= getMaxEnumValForReasonEnum(); i++) { + validReasons.push_back(toString(static_cast(i))); + } + + std::string commaSeparatedReasons; + { + llvm::raw_string_ostream rso(commaSeparatedReasons); + llvm::interleaveComma(validReasons, rso); + } + + return emitOpError() << "invalid provided Reason '" << getReason() + << "'. Valid Reasons are: [" << commaSeparatedReasons + << "]"; + } + return success(); }