From b649ef010219336b8cafb8a2b71c39d8865e8a5f Mon Sep 17 00:00:00 2001
From: Zentrik
Date: Mon, 14 Aug 2023 09:33:44 +0100
Subject: [PATCH] Add fastmath flag to PTXCompilerTarget (#492)

---
 src/ptx.jl | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/ptx.jl b/src/ptx.jl
index 5a9bc641..c215a15e 100644
--- a/src/ptx.jl
+++ b/src/ptx.jl
@@ -18,6 +18,8 @@ Base.@kwdef struct PTXCompilerTarget <: AbstractCompilerTarget
     blocks_per_sm::Union{Nothing,Int} = nothing
     maxregs::Union{Nothing,Int} = nothing
 
+    fastmath::Bool = Base.JLOptions().fast_math == 1
+
     # deprecated; remove with next major version
     exitable::Union{Nothing,Bool} = nothing
     unreachable::Union{Nothing,Bool} = nothing
@@ -33,6 +35,7 @@ function Base.hash(target::PTXCompilerTarget, h::UInt)
     h = hash(target.maxthreads, h)
     h = hash(target.blocks_per_sm, h)
     h = hash(target.maxregs, h)
+    h = hash(target.fastmath, h)
 
     h
 end
@@ -82,6 +85,7 @@ function Base.show(io::IO, @nospecialize(job::CompilerJob{PTXCompilerTarget}))
     job.config.target.maxthreads !== nothing && print(io, ", maxthreads=$(job.config.target.maxthreads)")
     job.config.target.blocks_per_sm !== nothing && print(io, ", blocks_per_sm=$(job.config.target.blocks_per_sm)")
     job.config.target.maxregs !== nothing && print(io, ", maxregs=$(job.config.target.maxregs)")
+    job.config.target.fastmath && print(io, ", fast math enabled")
 end
 
 const ptx_intrinsics = ("vprintf", "__assertfail", "malloc", "free")
@@ -424,7 +428,7 @@ function nvvm_reflect!(fun::LLVM.Function)
         # handle possible cases
         # XXX: put some of these property in the compiler job?
         #      and/or first set the "nvvm-reflect-*" module flag like Clang does?
-        fast_math = Base.JLOptions().fast_math == 1
+        fast_math = current_job.config.target.fastmath
        # NOTE: we follow nvcc's --use_fast_math
         reflect_val = if reflect_arg == "__CUDA_FTZ"
             # single-precision denormals support
@@ -433,7 +437,7 @@ function nvvm_reflect!(fun::LLVM.Function)
             # single-precision floating-point division and reciprocals.
             ConstantInt(reflect_typ, fast_math ? 0 : 1)
         elseif reflect_arg == "__CUDA_PREC_SQRT"
-            # single-precision denormals support
+            # single-precision floating point square roots.
             ConstantInt(reflect_typ, fast_math ? 0 : 1)
         elseif reflect_arg == "__CUDA_FMAD"
             # contraction of floating-point multiplies and adds/subtracts into
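
Usage sketch (not part of the patch): with this change, fast-math lowering for NVVM reflection is controlled per target instead of only by the global `--math-mode=fast` julia flag, which remains the default via `Base.JLOptions().fast_math`. The `cap` value below is an illustrative assumption; only the `fastmath` keyword comes from this patch.

    using GPUCompiler

    # `cap` is an assumed example value; it must match the device's compute capability.
    # `fastmath=true` opts this target into fast-math code generation regardless of
    # how julia itself was started.
    target = PTXCompilerTarget(; cap=v"7.0", fastmath=true)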