Skip to content

Commit

Permalink
Initialize pass options using pipeline options
Browse files Browse the repository at this point in the history
  • Loading branch information
nikita-kud authored and Maxim-Doronin committed Jun 5, 2024
1 parent 71f650b commit 83f6c04
Show file tree
Hide file tree
Showing 6 changed files with 148 additions and 0 deletions.
4 changes: 4 additions & 0 deletions mlir/include/mlir/Pass/Pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,10 @@ class Pass {
/// Copy the option values from 'other', which is another instance of this
/// pass.
void copyOptionValuesFrom(const Pass *other);

/// Copy the option values from 'other', which are PassPipeline options.
/// Here we copy only those options that have the same argument name.
void copyOptionValuesFrom(const detail::PassOptions &other);

private:
/// Out of line virtual method to ensure vtables and metadata are emitted to a
Expand Down
3 changes: 3 additions & 0 deletions mlir/include/mlir/Pass/PassOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,9 @@ class PassOptions : protected llvm::cl::SubCommand {
/// Copy the option values from 'other' into 'this', where 'other' has the
/// same options as 'this'.
void copyOptionValuesFrom(const PassOptions &other);

/// Copy only those options that have the same argument name.
void matchAndCopyOptionValuesFrom(const PassOptions &otherPassOptions);

/// Parse options out as key=value pairs that can then be handed off to the
/// `llvm::cl` command line passing infrastructure. Everything is space
Expand Down
6 changes: 6 additions & 0 deletions mlir/lib/Pass/Pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ void Pass::copyOptionValuesFrom(const Pass *other) {
passOptions.copyOptionValuesFrom(other->passOptions);
}

/// Copy the option values from 'other', which are PassPipeline options.
/// Here we copy only those options that have the same argument name.
void Pass::copyOptionValuesFrom(const PassOptions &other) {
passOptions.matchAndCopyOptionValuesFrom(other);
}

/// Prints out the pass in the textual representation of pipelines. If this is
/// an adaptor pass, print its pass managers.
void Pass::printAsTextualPipeline(raw_ostream &os) {
Expand Down
13 changes: 13 additions & 0 deletions mlir/lib/Pass/PassRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,19 @@ void detail::PassOptions::copyOptionValuesFrom(const PassOptions &other) {
std::get<0>(optionsIt)->copyValueFrom(*std::get<1>(optionsIt));
}

/// Copy only those options that have the same argument name.
void detail::PassOptions::matchAndCopyOptionValuesFrom(const PassOptions &other) {
for (auto* optionsIt : other.options) {
const auto& it = llvm::find_if(options, [&](OptionBase * option) {
return option->getArgStr() == optionsIt->getArgStr();
});

if (it != options.end()) {
(*it)->copyValueFrom(*optionsIt);
}
}
}

/// Parse in the next argument from the given options string. Returns a tuple
/// containing [the key of the option, the value of the option, updated
/// `options` string pointing after the parsed option].
Expand Down
1 change: 1 addition & 0 deletions mlir/unittests/Pass/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ add_mlir_unittest(MLIRPassTests
AnalysisManagerTest.cpp
PassManagerTest.cpp
PassPipelineParserTest.cpp
PassPipelineOptionsTest.cpp
)
target_link_libraries(MLIRPassTests
PRIVATE
Expand Down
121 changes: 121 additions & 0 deletions mlir/unittests/Pass/PassPipelineOptionsTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
//===- PassPipelineParserTest.cpp - Pass Parser unit tests ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "gtest/gtest.h"

#include <memory>

using namespace mlir;
using namespace mlir::detail;

namespace {

// these types are used for automatically generated code of pass
using StrPassOpt = ::mlir::Pass::Option<std::string>;
using IntPassOpt = ::mlir::Pass::Option<int>;
using BoolPassOpt = ::mlir::Pass::Option<bool>;

// these types are used for pipeline options that we manually pass to the constructor
using StrOption = mlir::detail::PassOptions::Option<std::string>;
using IntOption = mlir::detail::PassOptions::Option<int>;
using BoolOption = mlir::detail::PassOptions::Option<bool>;

const int intOptDefaultVal = 5;
const bool boolOptDefaultVal = true;

struct SimplePassWithOptions
: public PassWrapper<SimplePassWithOptions, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SimplePassWithOptions)

SimplePassWithOptions() = default;
SimplePassWithOptions(const SimplePassWithOptions &other) : PassWrapper(other) {}

SimplePassWithOptions(const detail::PassOptions& options) {
copyOptionValuesFrom(options);
}

LogicalResult initialize(MLIRContext *ctx) final {
return success();
}

void runOnOperation() override { }

public:
StrPassOpt strOpt{*this, "str-opt", ::llvm::cl::desc("string test option"), llvm::cl::init("")};
IntPassOpt intOpt{*this, "int-opt", ::llvm::cl::desc("int test option"), llvm::cl::init(intOptDefaultVal)};
BoolPassOpt boolOpt{*this, "bool-opt", ::llvm::cl::desc("bool test option"), llvm::cl::init(boolOptDefaultVal)};
};

TEST(PassPipelineOptionsTest, CopyAllOptions) {
struct DuplicatedOtions : ::mlir::PassPipelineOptions<DuplicatedOtions> {
StrOption strOpt{*this, "str-opt", ::llvm::cl::desc("string test option")};
IntOption intOpt{*this, "int-opt", ::llvm::cl::desc("int test option"), llvm::cl::init(intOptDefaultVal)};
BoolOption boolOpt{*this, "bool-opt", ::llvm::cl::desc("bool test option"), llvm::cl::init(boolOptDefaultVal)};
};

const auto expectedStrVal = "test1";
const auto expectedIntVal = -intOptDefaultVal;
const auto expectedBoolVal = !boolOptDefaultVal;

DuplicatedOtions options;
options.strOpt.setValue(expectedStrVal);
options.intOpt.setValue(expectedIntVal);
options.boolOpt.setValue(expectedBoolVal);

const auto& pass = std::make_unique<SimplePassWithOptions>(options);

EXPECT_EQ(pass->strOpt.getValue(), expectedStrVal);
EXPECT_EQ(pass->intOpt.getValue(), expectedIntVal);
EXPECT_EQ(pass->boolOpt.getValue(), expectedBoolVal);
}

TEST(PassPipelineOptionsTest, CopyMatchedOptions) {
struct SomePipelineOptions : ::mlir::PassPipelineOptions<SomePipelineOptions> {
StrOption strOpt{*this, "str-opt", ::llvm::cl::desc("string test option")};
IntOption intOpt{*this, "int-opt", ::llvm::cl::desc("int test option")};
StrOption anotherStrOpt{*this, "another-str-pipeline-opt",
::llvm::cl::desc("there is no such option in SimplePassWithOptions"), llvm::cl::init("anotherOptVal")};
IntOption anotherIntOpt{*this, "another-int-pipeline-opt",
::llvm::cl::desc("there is no such option in SimplePassWithOptions"), llvm::cl::init(10)};
};

const auto expectedStrVal = "test2";
const auto expectedIntVal = -intOptDefaultVal;

SomePipelineOptions options;
options.strOpt.setValue(expectedStrVal);
options.intOpt.setValue(expectedIntVal);

const auto pass = std::make_unique<SimplePassWithOptions>(options);

EXPECT_EQ(pass->strOpt.getValue(), expectedStrVal);
EXPECT_EQ(pass->intOpt.getValue(), expectedIntVal);
EXPECT_EQ(pass->boolOpt.getValue(), boolOptDefaultVal);
}

TEST(PassPipelineOptionsTest, NoMatchedOptions) {
struct SomePipelineOptions : ::mlir::PassPipelineOptions<SomePipelineOptions> {
StrOption anotherStrOpt{*this, "another-str-pipeline-opt",
::llvm::cl::desc("there is no such option in SimplePassWithOptions"), llvm::cl::init("anotherOptVal")};
IntOption anotherIntOpt{*this, "another-int-pipeline-opt",
::llvm::cl::desc("there is no such option in SimplePassWithOptions"), llvm::cl::init(10)};
};

SomePipelineOptions options;
const auto pass = std::make_unique<SimplePassWithOptions>(options);

EXPECT_EQ(pass->strOpt.getValue(), "");
EXPECT_EQ(pass->intOpt.getValue(), intOptDefaultVal);
EXPECT_EQ(pass->boolOpt.getValue(), boolOptDefaultVal);
}

} // namespace

0 comments on commit 83f6c04

Please sign in to comment.