Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[flang] Inline hlfir.matmul[_transpose]. #122821

Merged
merged 3 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions flang/include/flang/Optimizer/Builder/FIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,15 @@ elideLengthsAlreadyInType(mlir::Type type, mlir::ValueRange lenParams);
/// Get the address space which should be used for allocas
uint64_t getAllocaAddressSpace(mlir::DataLayout *dataLayout);

/// The two vectors of MLIR values have the following property:
/// \p extents1[i] must have the same value as \p extents2[i]
/// The function returns a new vector of MLIR values that preserves
/// the same property vs \p extents1 and \p extents2, but allows
/// more optimizations. For example, if extents1[j] is a known constant,
/// and extents2[j] is not, then result[j] is the MLIR value extents1[j].
llvm::SmallVector<mlir::Value> deduceOptimalExtents(mlir::ValueRange extents1,
mlir::ValueRange extents2);

} // namespace fir::factory

#endif // FORTRAN_OPTIMIZER_BUILDER_FIRBUILDER_H
5 changes: 5 additions & 0 deletions flang/include/flang/Optimizer/Builder/HLFIRTools.h
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,11 @@ genTypeAndKindConvert(mlir::Location loc, fir::FirOpBuilder &builder,
hlfir::Entity source, mlir::Type toType,
bool preserveLowerBounds);

/// A shortcut for loadTrivialScalar(getElementAt()),
/// which designates and loads an element of an array.
Entity loadElementAt(mlir::Location loc, fir::FirOpBuilder &builder,
Entity entity, mlir::ValueRange oneBasedIndices);

} // namespace hlfir

#endif // FORTRAN_OPTIMIZER_BUILDER_HLFIRTOOLS_H
11 changes: 11 additions & 0 deletions flang/include/flang/Optimizer/HLFIR/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,17 @@ def LowerHLFIROrderedAssignments : Pass<"lower-hlfir-ordered-assignments", "::ml

def SimplifyHLFIRIntrinsics : Pass<"simplify-hlfir-intrinsics"> {
let summary = "Simplify HLFIR intrinsic operations that don't need to result in runtime calls";
let options = [Option<"allowNewSideEffects", "allow-new-side-effects", "bool",
/*default=*/"false",
"If enabled, then the HLFIR operations simplification "
"may introduce operations with side effects. "
"For example, hlfir.matmul may be inlined as "
"and hlfir.eval_in_mem with hlfir.assign inside it."
"The hlfir.assign has a write effect on the memory "
"argument of hlfir.eval_in_mem, which may block "
"some existing MLIR transformations (e.g. CSE) "
"that otherwise would have been possible across "
"the hlfir.matmul.">];
}

def InlineElementals : Pass<"inline-elementals"> {
Expand Down
14 changes: 14 additions & 0 deletions flang/lib/Optimizer/Builder/FIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1740,3 +1740,17 @@ uint64_t fir::factory::getAllocaAddressSpace(mlir::DataLayout *dataLayout) {
return mlir::cast<mlir::IntegerAttr>(addrSpace).getUInt();
return 0;
}

llvm::SmallVector<mlir::Value>
fir::factory::deduceOptimalExtents(mlir::ValueRange extents1,
mlir::ValueRange extents2) {
llvm::SmallVector<mlir::Value> extents;
extents.reserve(extents1.size());
for (auto [extent1, extent2] : llvm::zip(extents1, extents2)) {
if (!fir::getIntIfConstant(extent1) && fir::getIntIfConstant(extent2))
extents.push_back(extent2);
else
extents.push_back(extent1);
}
return extents;
}
17 changes: 14 additions & 3 deletions flang/lib/Optimizer/Builder/HLFIRTools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -939,8 +939,10 @@ llvm::SmallVector<mlir::Value> hlfir::genLoopNestWithReductions(
doLoop = builder.create<fir::DoLoopOp>(loc, one, ub, one, isUnordered,
/*finalCountValue=*/false,
parentLoop.getRegionIterArgs());
// Return the results of the child loop from its parent loop.
builder.create<fir::ResultOp>(loc, doLoop.getResults());
if (!reductionInits.empty()) {
// Return the results of the child loop from its parent loop.
builder.create<fir::ResultOp>(loc, doLoop.getResults());
}
}

builder.setInsertionPointToStart(doLoop.getBody());
Expand All @@ -955,7 +957,8 @@ llvm::SmallVector<mlir::Value> hlfir::genLoopNestWithReductions(
reductionValues =
genBody(loc, builder, oneBasedIndices, parentLoop.getRegionIterArgs());
builder.setInsertionPointToEnd(parentLoop.getBody());
builder.create<fir::ResultOp>(loc, reductionValues);
if (!reductionValues.empty())
builder.create<fir::ResultOp>(loc, reductionValues);
builder.setInsertionPointAfter(outerLoop);
return outerLoop->getResults();
}
Expand Down Expand Up @@ -1410,3 +1413,11 @@ void hlfir::computeEvaluateOpIn(mlir::Location loc, fir::FirOpBuilder &builder,
builder.clone(op, mapper);
return;
}

hlfir::Entity hlfir::loadElementAt(mlir::Location loc,
fir::FirOpBuilder &builder,
hlfir::Entity entity,
mlir::ValueRange oneBasedIndices) {
return loadTrivialScalar(loc, builder,
getElementAt(loc, builder, entity, oneBasedIndices));
}
Loading
Loading