Skip to content

Commit 8fe03b0

Browse files
clementvalIcohedron
authored andcommitted
[flang][cuda] Lower syncwarp to NVVM intrinsic (llvm#126164)
1 parent 594e189 commit 8fe03b0

File tree

4 files changed

+18
-4
lines changed

4 files changed

+18
-4
lines changed

flang/include/flang/Optimizer/Builder/IntrinsicCall.h

+1
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,7 @@ struct IntrinsicLibrary {
406406
mlir::Value genSyncThreadsAnd(mlir::Type, llvm::ArrayRef<mlir::Value>);
407407
mlir::Value genSyncThreadsCount(mlir::Type, llvm::ArrayRef<mlir::Value>);
408408
mlir::Value genSyncThreadsOr(mlir::Type, llvm::ArrayRef<mlir::Value>);
409+
void genSyncWarp(llvm::ArrayRef<fir::ExtendedValue>);
409410
fir::ExtendedValue genSystem(std::optional<mlir::Type>,
410411
mlir::ArrayRef<fir::ExtendedValue> args);
411412
void genSystemClock(llvm::ArrayRef<fir::ExtendedValue>);

flang/lib/Optimizer/Builder/IntrinsicCall.cpp

+13
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,7 @@ static constexpr IntrinsicHandler handlers[]{
680680
{"syncthreads_and", &I::genSyncThreadsAnd, {}, /*isElemental=*/false},
681681
{"syncthreads_count", &I::genSyncThreadsCount, {}, /*isElemental=*/false},
682682
{"syncthreads_or", &I::genSyncThreadsOr, {}, /*isElemental=*/false},
683+
{"syncwarp", &I::genSyncWarp, {}, /*isElemental=*/false},
683684
{"system",
684685
&I::genSystem,
685686
{{{"command", asBox}, {"exitstat", asBox, handleDynamicOptional}}},
@@ -7704,6 +7705,18 @@ IntrinsicLibrary::genSyncThreadsOr(mlir::Type resultType,
77047705
return builder.create<fir::CallOp>(loc, funcOp, args).getResult(0);
77057706
}
77067707

7708+
// SYNCWARP
7709+
void IntrinsicLibrary::genSyncWarp(llvm::ArrayRef<fir::ExtendedValue> args) {
7710+
assert(args.size() == 1);
7711+
constexpr llvm::StringLiteral funcName = "llvm.nvvm.bar.warp.sync";
7712+
mlir::Value mask = fir::getBase(args[0]);
7713+
mlir::FunctionType funcType =
7714+
mlir::FunctionType::get(builder.getContext(), {mask.getType()}, {});
7715+
auto funcOp = builder.createFunction(loc, funcName, funcType);
7716+
llvm::SmallVector<mlir::Value> argsList{mask};
7717+
builder.create<fir::CallOp>(loc, funcOp, argsList);
7718+
}
7719+
77077720
// SYSTEM
77087721
fir::ExtendedValue
77097722
IntrinsicLibrary::genSystem(std::optional<mlir::Type> resultType,

flang/module/cudadevice.f90

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ attributes(device) integer function syncthreads_or(value)
4949
public :: syncthreads_or
5050

5151
interface
52-
attributes(device) subroutine syncwarp(mask) bind(c, name='__syncwarp')
52+
attributes(device) subroutine syncwarp(mask)
5353
integer, value :: mask
5454
end subroutine
5555
end interface

flang/test/Lower/CUDA/cuda-device-proc.cuf

+3-3
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ end
4747

4848
! CHECK-LABEL: func.func @_QPdevsub() attributes {cuf.proc_attr = #cuf.cuda_proc<global>}
4949
! CHECK: fir.call @llvm.nvvm.barrier0() fastmath<contract> : () -> ()
50-
! CHECK: fir.call @__syncwarp(%{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i32) -> ()
50+
! CHECK: fir.call @llvm.nvvm.bar.warp.sync(%c1{{.*}}) fastmath<contract> : (i32) -> ()
5151
! CHECK: fir.call @llvm.nvvm.membar.gl() fastmath<contract> : () -> ()
5252
! CHECK: fir.call @llvm.nvvm.membar.cta() fastmath<contract> : () -> ()
5353
! CHECK: fir.call @llvm.nvvm.membar.sys() fastmath<contract> : () -> ()
@@ -102,13 +102,13 @@ end
102102
! CHECK-LABEL: func.func @_QPhost1()
103103
! CHECK: cuf.kernel
104104
! CHECK: fir.call @llvm.nvvm.barrier0() fastmath<contract> : () -> ()
105-
! CHECK: fir.call @__syncwarp(%c1{{.*}}) proc_attrs<bind_c> fastmath<contract> : (i32) -> ()
105+
! CHECK: fir.call @llvm.nvvm.bar.warp.sync(%c1{{.*}}) fastmath<contract> : (i32) -> ()
106106
! CHECK: fir.call @llvm.nvvm.barrier0.and(%c1{{.*}}) fastmath<contract> : (i32) -> i32
107107
! CHECK: fir.call @llvm.nvvm.barrier0.popc(%c1{{.*}}) fastmath<contract> : (i32) -> i32
108108
! CHECK: fir.call @llvm.nvvm.barrier0.or(%c1{{.*}}) fastmath<contract> : (i32) -> i32
109109

110110
! CHECK: func.func private @llvm.nvvm.barrier0()
111-
! CHECK: func.func private @__syncwarp(i32) attributes {cuf.proc_attr = #cuf.cuda_proc<device>, fir.bindc_name = "__syncwarp", fir.proc_attrs = #fir.proc_attrs<bind_c>}
111+
! CHECK: func.func private @llvm.nvvm.bar.warp.sync(i32)
112112
! CHECK: func.func private @llvm.nvvm.membar.gl()
113113
! CHECK: func.func private @llvm.nvvm.membar.cta()
114114
! CHECK: func.func private @llvm.nvvm.membar.sys()

0 commit comments

Comments
 (0)